[PATCH] rsh.c env and quoting cleanup, take 2
[git.git] / gitMergeCommon.py
1 import sys, re, os, traceback
2 from sets import Set
3
4 def die(*args):
5     printList(args, sys.stderr)
6     sys.exit(2)
7
8 def printList(list, file=sys.stdout):
9     for x in list:
10         file.write(str(x))
11         file.write(' ')
12     file.write('\n')
13
14 if sys.version_info[0] < 2 or \
15        (sys.version_info[0] == 2 and sys.version_info[1] < 4):
16     die('Python version 2.4 required, found', \
17         str(sys.version_info[0])+'.'+str(sys.version_info[1])+'.'+ \
18         str(sys.version_info[2]))
19
20 import subprocess
21
22 # Debugging machinery
23 # -------------------
24
25 DEBUG = 0
26 functionsToDebug = Set()
27
28 def addDebug(func):
29     if type(func) == str:
30         functionsToDebug.add(func)
31     else:
32         functionsToDebug.add(func.func_name)
33
34 def debug(*args):
35     if DEBUG:
36         funcName = traceback.extract_stack()[-2][2]
37         if funcName in functionsToDebug:
38             printList(args)
39
40 # Program execution
41 # -----------------
42
43 class ProgramError(Exception):
44     def __init__(self, progStr, error):
45         self.progStr = progStr
46         self.error = error
47
48     def __str__(self):
49         return self.progStr + ': ' + self.error
50
51 addDebug('runProgram')
52 def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
53     debug('runProgram prog:', str(prog), 'input:', str(input))
54     if type(prog) is str:
55         progStr = prog
56     else:
57         progStr = ' '.join(prog)
58     
59     try:
60         if pipeOutput:
61             stderr = subprocess.STDOUT
62             stdout = subprocess.PIPE
63         else:
64             stderr = None
65             stdout = None
66         pop = subprocess.Popen(prog,
67                                shell = type(prog) is str,
68                                stderr=stderr,
69                                stdout=stdout,
70                                stdin=subprocess.PIPE,
71                                env=env)
72     except OSError, e:
73         debug('strerror:', e.strerror)
74         raise ProgramError(progStr, e.strerror)
75
76     if input != None:
77         pop.stdin.write(input)
78     pop.stdin.close()
79
80     if pipeOutput:
81         out = pop.stdout.read()
82     else:
83         out = ''
84
85     code = pop.wait()
86     if returnCode:
87         ret = [out, code]
88     else:
89         ret = out
90     if code != 0 and not returnCode:
91         debug('error output:', out)
92         debug('prog:', prog)
93         raise ProgramError(progStr, out)
94 #    debug('output:', out.replace('\0', '\n'))
95     return ret
96
97 # Code for computing common ancestors
98 # -----------------------------------
99
100 currentId = 0
101 def getUniqueId():
102     global currentId
103     currentId += 1
104     return currentId
105
106 # The 'virtual' commit objects have SHAs which are integers
107 shaRE = re.compile('^[0-9a-f]{40}$')
108 def isSha(obj):
109     return (type(obj) is str and bool(shaRE.match(obj))) or \
110            (type(obj) is int and obj >= 1)
111
112 class Commit:
113     def __init__(self, sha, parents, tree=None):
114         self.parents = parents
115         self.firstLineMsg = None
116         self.children = []
117
118         if tree:
119             tree = tree.rstrip()
120             assert(isSha(tree))
121         self._tree = tree
122
123         if not sha:
124             self.sha = getUniqueId()
125             self.virtual = True
126             self.firstLineMsg = 'virtual commit'
127             assert(isSha(tree))
128         else:
129             self.virtual = False
130             self.sha = sha.rstrip()
131         assert(isSha(self.sha))
132
133     def tree(self):
134         self.getInfo()
135         assert(self._tree != None)
136         return self._tree
137
138     def shortInfo(self):
139         self.getInfo()
140         return str(self.sha) + ' ' + self.firstLineMsg
141
142     def __str__(self):
143         return self.shortInfo()
144
145     def getInfo(self):
146         if self.virtual or self.firstLineMsg != None:
147             return
148         else:
149             info = runProgram(['git-cat-file', 'commit', self.sha])
150             info = info.split('\n')
151             msg = False
152             for l in info:
153                 if msg:
154                     self.firstLineMsg = l
155                     break
156                 else:
157                     if l.startswith('tree'):
158                         self._tree = l[5:].rstrip()
159                     elif l == '':
160                         msg = True
161
162 class Graph:
163     def __init__(self):
164         self.commits = []
165         self.shaMap = {}
166
167     def addNode(self, node):
168         assert(isinstance(node, Commit))
169         self.shaMap[node.sha] = node
170         self.commits.append(node)
171         for p in node.parents:
172             p.children.append(node)
173         return node
174
175     def reachableNodes(self, n1, n2):
176         res = {}
177         def traverse(n):
178             res[n] = True
179             for p in n.parents:
180                 traverse(p)
181
182         traverse(n1)
183         traverse(n2)
184         return res
185
186     def fixParents(self, node):
187         for x in range(0, len(node.parents)):
188             node.parents[x] = self.shaMap[node.parents[x]]
189
190 # addDebug('buildGraph')
191 def buildGraph(heads):
192     debug('buildGraph heads:', heads)
193     for h in heads:
194         assert(isSha(h))
195
196     g = Graph()
197
198     out = runProgram(['git-rev-list', '--parents'] + heads)
199     for l in out.split('\n'):
200         if l == '':
201             continue
202         shas = l.split(' ')
203
204         # This is a hack, we temporarily use the 'parents' attribute
205         # to contain a list of SHA1:s. They are later replaced by proper
206         # Commit objects.
207         c = Commit(shas[0], shas[1:])
208
209         g.commits.append(c)
210         g.shaMap[c.sha] = c
211
212     for c in g.commits:
213         g.fixParents(c)
214
215     for c in g.commits:
216         for p in c.parents:
217             p.children.append(c)
218     return g
219
220 # Write the empty tree to the object database and return its SHA1
221 def writeEmptyTree():
222     tmpIndex = os.environ['GIT_DIR'] + '/merge-tmp-index'
223     def delTmpIndex():
224         try:
225             os.unlink(tmpIndex)
226         except OSError:
227             pass
228     delTmpIndex()
229     newEnv = os.environ.copy()
230     newEnv['GIT_INDEX_FILE'] = tmpIndex
231     res = runProgram(['git-write-tree'], env=newEnv).rstrip()
232     delTmpIndex()
233     return res
234
235 def addCommonRoot(graph):
236     roots = []
237     for c in graph.commits:
238         if len(c.parents) == 0:
239             roots.append(c)
240
241     superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
242     graph.addNode(superRoot)
243     for r in roots:
244         r.parents = [superRoot]
245     superRoot.children = roots
246     return superRoot
247
248 def getCommonAncestors(graph, commit1, commit2):
249     '''Find the common ancestors for commit1 and commit2'''
250     assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
251
252     def traverse(start, set):
253         stack = [start]
254         while len(stack) > 0:
255             el = stack.pop()
256             set.add(el)
257             for p in el.parents:
258                 if p not in set:
259                     stack.append(p)
260     h1Set = Set()
261     h2Set = Set()
262     traverse(commit1, h1Set)
263     traverse(commit2, h2Set)
264     shared = h1Set.intersection(h2Set)
265
266     if len(shared) == 0:
267         shared = [addCommonRoot(graph)]
268         
269     res = Set()
270
271     for s in shared:
272         if len([c for c in s.children if c in shared]) == 0:
273             res.add(s)
274     return list(res)