# gitMergeCommon.py on commit Merge branch 'master' of . (3db6b22)
   1import sys, re, os, traceback
   2from sets import Set
   3
   4if sys.version_info[0] < 2 or \
   5       (sys.version_info[0] == 2 and sys.version_info[1] < 4):
   6    print 'Python version 2.4 required, found', \
   7          str(sys.version_info[0])+'.'+str(sys.version_info[1])+'.'+ \
   8          str(sys.version_info[2])
   9    sys.exit(1)
  10
  11import subprocess
  12
  13def die(*args):
  14    printList(args, sys.stderr)
  15    sys.exit(2)
  16
  17# Debugging machinery
  18# -------------------
  19
  20DEBUG = 0
  21functionsToDebug = Set()
  22
  23def addDebug(func):
  24    if type(func) == str:
  25        functionsToDebug.add(func)
  26    else:
  27        functionsToDebug.add(func.func_name)
  28
  29def debug(*args):
  30    if DEBUG:
  31        funcName = traceback.extract_stack()[-2][2]
  32        if funcName in functionsToDebug:
  33            printList(args)
  34
  35def printList(list, file=sys.stdout):
  36    for x in list:
  37        file.write(str(x))
  38        file.write(' ')
  39    file.write('\n')
  40
  41# Program execution
  42# -----------------
  43
  44class ProgramError(Exception):
  45    def __init__(self, progStr, error):
  46        self.progStr = progStr
  47        self.error = error
  48
  49    def __str__(self):
  50        return self.progStr + ': ' + self.error
  51
  52addDebug('runProgram')
  53def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
  54    debug('runProgram prog:', str(prog), 'input:', str(input))
  55    if type(prog) is str:
  56        progStr = prog
  57    else:
  58        progStr = ' '.join(prog)
  59    
  60    try:
  61        if pipeOutput:
  62            stderr = subprocess.STDOUT
  63            stdout = subprocess.PIPE
  64        else:
  65            stderr = None
  66            stdout = None
  67        pop = subprocess.Popen(prog,
  68                               shell = type(prog) is str,
  69                               stderr=stderr,
  70                               stdout=stdout,
  71                               stdin=subprocess.PIPE,
  72                               env=env)
  73    except OSError, e:
  74        debug('strerror:', e.strerror)
  75        raise ProgramError(progStr, e.strerror)
  76
  77    if input != None:
  78        pop.stdin.write(input)
  79    pop.stdin.close()
  80
  81    if pipeOutput:
  82        out = pop.stdout.read()
  83    else:
  84        out = ''
  85
  86    code = pop.wait()
  87    if returnCode:
  88        ret = [out, code]
  89    else:
  90        ret = out
  91    if code != 0 and not returnCode:
  92        debug('error output:', out)
  93        debug('prog:', prog)
  94        raise ProgramError(progStr, out)
  95#    debug('output:', out.replace('\0', '\n'))
  96    return ret
  97
  98# Code for computing common ancestors
  99# -----------------------------------
 100
 101currentId = 0
 102def getUniqueId():
 103    global currentId
 104    currentId += 1
 105    return currentId
 106
 107# The 'virtual' commit objects have SHAs which are integers
 108shaRE = re.compile('^[0-9a-f]{40}$')
 109def isSha(obj):
 110    return (type(obj) is str and bool(shaRE.match(obj))) or \
 111           (type(obj) is int and obj >= 1)
 112
 113class Commit:
 114    def __init__(self, sha, parents, tree=None):
 115        self.parents = parents
 116        self.firstLineMsg = None
 117        self.children = []
 118
 119        if tree:
 120            tree = tree.rstrip()
 121            assert(isSha(tree))
 122        self._tree = tree
 123
 124        if not sha:
 125            self.sha = getUniqueId()
 126            self.virtual = True
 127            self.firstLineMsg = 'virtual commit'
 128            assert(isSha(tree))
 129        else:
 130            self.virtual = False
 131            self.sha = sha.rstrip()
 132        assert(isSha(self.sha))
 133
 134    def tree(self):
 135        self.getInfo()
 136        assert(self._tree != None)
 137        return self._tree
 138
 139    def shortInfo(self):
 140        self.getInfo()
 141        return str(self.sha) + ' ' + self.firstLineMsg
 142
 143    def __str__(self):
 144        return self.shortInfo()
 145
 146    def getInfo(self):
 147        if self.virtual or self.firstLineMsg != None:
 148            return
 149        else:
 150            info = runProgram(['git-cat-file', 'commit', self.sha])
 151            info = info.split('\n')
 152            msg = False
 153            for l in info:
 154                if msg:
 155                    self.firstLineMsg = l
 156                    break
 157                else:
 158                    if l.startswith('tree'):
 159                        self._tree = l[5:].rstrip()
 160                    elif l == '':
 161                        msg = True
 162
 163class Graph:
 164    def __init__(self):
 165        self.commits = []
 166        self.shaMap = {}
 167
 168    def addNode(self, node):
 169        assert(isinstance(node, Commit))
 170        self.shaMap[node.sha] = node
 171        self.commits.append(node)
 172        for p in node.parents:
 173            p.children.append(node)
 174        return node
 175
 176    def reachableNodes(self, n1, n2):
 177        res = {}
 178        def traverse(n):
 179            res[n] = True
 180            for p in n.parents:
 181                traverse(p)
 182
 183        traverse(n1)
 184        traverse(n2)
 185        return res
 186
 187    def fixParents(self, node):
 188        for x in range(0, len(node.parents)):
 189            node.parents[x] = self.shaMap[node.parents[x]]
 190
 191# addDebug('buildGraph')
 192def buildGraph(heads):
 193    debug('buildGraph heads:', heads)
 194    for h in heads:
 195        assert(isSha(h))
 196
 197    g = Graph()
 198
 199    out = runProgram(['git-rev-list', '--parents'] + heads)
 200    for l in out.split('\n'):
 201        if l == '':
 202            continue
 203        shas = l.split(' ')
 204
 205        # This is a hack, we temporarily use the 'parents' attribute
 206        # to contain a list of SHA1:s. They are later replaced by proper
 207        # Commit objects.
 208        c = Commit(shas[0], shas[1:])
 209
 210        g.commits.append(c)
 211        g.shaMap[c.sha] = c
 212
 213    for c in g.commits:
 214        g.fixParents(c)
 215
 216    for c in g.commits:
 217        for p in c.parents:
 218            p.children.append(c)
 219    return g
 220
 221# Write the empty tree to the object database and return its SHA1
 222def writeEmptyTree():
 223    tmpIndex = os.environ['GIT_DIR'] + '/merge-tmp-index'
 224    def delTmpIndex():
 225        try:
 226            os.unlink(tmpIndex)
 227        except OSError:
 228            pass
 229    delTmpIndex()
 230    newEnv = os.environ.copy()
 231    newEnv['GIT_INDEX_FILE'] = tmpIndex
 232    res = runProgram(['git-write-tree'], env=newEnv).rstrip()
 233    delTmpIndex()
 234    return res
 235
 236def addCommonRoot(graph):
 237    roots = []
 238    for c in graph.commits:
 239        if len(c.parents) == 0:
 240            roots.append(c)
 241
 242    superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
 243    graph.addNode(superRoot)
 244    for r in roots:
 245        r.parents = [superRoot]
 246    superRoot.children = roots
 247    return superRoot
 248
 249def getCommonAncestors(graph, commit1, commit2):
 250    '''Find the common ancestors for commit1 and commit2'''
 251    assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
 252
 253    def traverse(start, set):
 254        stack = [start]
 255        while len(stack) > 0:
 256            el = stack.pop()
 257            set.add(el)
 258            for p in el.parents:
 259                if p not in set:
 260                    stack.append(p)
 261    h1Set = Set()
 262    h2Set = Set()
 263    traverse(commit1, h1Set)
 264    traverse(commit2, h2Set)
 265    shared = h1Set.intersection(h2Set)
 266
 267    if len(shared) == 0:
 268        shared = [addCommonRoot(graph)]
 269        
 270    res = Set()
 271
 272    for s in shared:
 273        if len([c for c in s.children if c in shared]) == 0:
 274            res.add(s)
 275    return list(res)