# gitMergeCommon.py on commit "nicer eye candies for pack-objects" (b2504a0)
   1#
   2# Copyright (C) 2005 Fredrik Kuivinen
   3#
   4
   5import sys, re, os, traceback
   6from sets import Set
   7
   8def die(*args):
   9    printList(args, sys.stderr)
  10    sys.exit(2)
  11
  12def printList(list, file=sys.stdout):
  13    for x in list:
  14        file.write(str(x))
  15        file.write(' ')
  16    file.write('\n')
  17
  18import subprocess
  19
# Debugging machinery
# -------------------

# debug() prints only while DEBUG is true, and only when it is called from a
# function whose name is in functionsToDebug (names are added by addDebug()).
DEBUG = 0
functionsToDebug = Set()
  25
  26def addDebug(func):
  27    if type(func) == str:
  28        functionsToDebug.add(func)
  29    else:
  30        functionsToDebug.add(func.func_name)
  31
  32def debug(*args):
  33    if DEBUG:
  34        funcName = traceback.extract_stack()[-2][2]
  35        if funcName in functionsToDebug:
  36            printList(args)
  37
  38# Program execution
  39# -----------------
  40
  41class ProgramError(Exception):
  42    def __init__(self, progStr, error):
  43        self.progStr = progStr
  44        self.error = error
  45
  46    def __str__(self):
  47        return self.progStr + ': ' + self.error
  48
# Register runProgram so its debug() calls print whenever DEBUG is enabled.
addDebug('runProgram')
  50def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
  51    debug('runProgram prog:', str(prog), 'input:', str(input))
  52    if type(prog) is str:
  53        progStr = prog
  54    else:
  55        progStr = ' '.join(prog)
  56    
  57    try:
  58        if pipeOutput:
  59            stderr = subprocess.STDOUT
  60            stdout = subprocess.PIPE
  61        else:
  62            stderr = None
  63            stdout = None
  64        pop = subprocess.Popen(prog,
  65                               shell = type(prog) is str,
  66                               stderr=stderr,
  67                               stdout=stdout,
  68                               stdin=subprocess.PIPE,
  69                               env=env)
  70    except OSError, e:
  71        debug('strerror:', e.strerror)
  72        raise ProgramError(progStr, e.strerror)
  73
  74    if input != None:
  75        pop.stdin.write(input)
  76    pop.stdin.close()
  77
  78    if pipeOutput:
  79        out = pop.stdout.read()
  80    else:
  81        out = ''
  82
  83    code = pop.wait()
  84    if returnCode:
  85        ret = [out, code]
  86    else:
  87        ret = out
  88    if code != 0 and not returnCode:
  89        debug('error output:', out)
  90        debug('prog:', prog)
  91        raise ProgramError(progStr, out)
  92#    debug('output:', out.replace('\0', '\n'))
  93    return ret
  94
  95# Code for computing common ancestors
  96# -----------------------------------
  97
  98currentId = 0
  99def getUniqueId():
 100    global currentId
 101    currentId += 1
 102    return currentId
 103
 104# The 'virtual' commit objects have SHAs which are integers
 105shaRE = re.compile('^[0-9a-f]{40}$')
 106def isSha(obj):
 107    return (type(obj) is str and bool(shaRE.match(obj))) or \
 108           (type(obj) is int and obj >= 1)
 109
 110class Commit(object):
 111    __slots__ = ['parents', 'firstLineMsg', 'children', '_tree', 'sha',
 112                 'virtual']
 113
 114    def __init__(self, sha, parents, tree=None):
 115        self.parents = parents
 116        self.firstLineMsg = None
 117        self.children = []
 118
 119        if tree:
 120            tree = tree.rstrip()
 121            assert(isSha(tree))
 122        self._tree = tree
 123
 124        if not sha:
 125            self.sha = getUniqueId()
 126            self.virtual = True
 127            self.firstLineMsg = 'virtual commit'
 128            assert(isSha(tree))
 129        else:
 130            self.virtual = False
 131            self.sha = sha.rstrip()
 132        assert(isSha(self.sha))
 133
 134    def tree(self):
 135        self.getInfo()
 136        assert(self._tree != None)
 137        return self._tree
 138
 139    def shortInfo(self):
 140        self.getInfo()
 141        return str(self.sha) + ' ' + self.firstLineMsg
 142
 143    def __str__(self):
 144        return self.shortInfo()
 145
 146    def getInfo(self):
 147        if self.virtual or self.firstLineMsg != None:
 148            return
 149        else:
 150            info = runProgram(['git-cat-file', 'commit', self.sha])
 151            info = info.split('\n')
 152            msg = False
 153            for l in info:
 154                if msg:
 155                    self.firstLineMsg = l
 156                    break
 157                else:
 158                    if l.startswith('tree'):
 159                        self._tree = l[5:].rstrip()
 160                    elif l == '':
 161                        msg = True
 162
 163class Graph:
 164    def __init__(self):
 165        self.commits = []
 166        self.shaMap = {}
 167
 168    def addNode(self, node):
 169        assert(isinstance(node, Commit))
 170        self.shaMap[node.sha] = node
 171        self.commits.append(node)
 172        for p in node.parents:
 173            p.children.append(node)
 174        return node
 175
 176    def reachableNodes(self, n1, n2):
 177        res = {}
 178        def traverse(n):
 179            res[n] = True
 180            for p in n.parents:
 181                traverse(p)
 182
 183        traverse(n1)
 184        traverse(n2)
 185        return res
 186
 187    def fixParents(self, node):
 188        for x in range(0, len(node.parents)):
 189            node.parents[x] = self.shaMap[node.parents[x]]
 190
 191# addDebug('buildGraph')
 192def buildGraph(heads):
 193    debug('buildGraph heads:', heads)
 194    for h in heads:
 195        assert(isSha(h))
 196
 197    g = Graph()
 198
 199    out = runProgram(['git-rev-list', '--parents'] + heads)
 200    for l in out.split('\n'):
 201        if l == '':
 202            continue
 203        shas = l.split(' ')
 204
 205        # This is a hack, we temporarily use the 'parents' attribute
 206        # to contain a list of SHA1:s. They are later replaced by proper
 207        # Commit objects.
 208        c = Commit(shas[0], shas[1:])
 209
 210        g.commits.append(c)
 211        g.shaMap[c.sha] = c
 212
 213    for c in g.commits:
 214        g.fixParents(c)
 215
 216    for c in g.commits:
 217        for p in c.parents:
 218            p.children.append(c)
 219    return g
 220
 221# Write the empty tree to the object database and return its SHA1
 222def writeEmptyTree():
 223    tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index'
 224    def delTmpIndex():
 225        try:
 226            os.unlink(tmpIndex)
 227        except OSError:
 228            pass
 229    delTmpIndex()
 230    newEnv = os.environ.copy()
 231    newEnv['GIT_INDEX_FILE'] = tmpIndex
 232    res = runProgram(['git-write-tree'], env=newEnv).rstrip()
 233    delTmpIndex()
 234    return res
 235
 236def addCommonRoot(graph):
 237    roots = []
 238    for c in graph.commits:
 239        if len(c.parents) == 0:
 240            roots.append(c)
 241
 242    superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
 243    graph.addNode(superRoot)
 244    for r in roots:
 245        r.parents = [superRoot]
 246    superRoot.children = roots
 247    return superRoot
 248
 249def getCommonAncestors(graph, commit1, commit2):
 250    '''Find the common ancestors for commit1 and commit2'''
 251    assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
 252
 253    def traverse(start, set):
 254        stack = [start]
 255        while len(stack) > 0:
 256            el = stack.pop()
 257            set.add(el)
 258            for p in el.parents:
 259                if p not in set:
 260                    stack.append(p)
 261    h1Set = Set()
 262    h2Set = Set()
 263    traverse(commit1, h1Set)
 264    traverse(commit2, h2Set)
 265    shared = h1Set.intersection(h2Set)
 266
 267    if len(shared) == 0:
 268        shared = [addCommonRoot(graph)]
 269        
 270    res = Set()
 271
 272    for s in shared:
 273        if len([c for c in s.children if c in shared]) == 0:
 274            res.add(s)
 275    return list(res)