blockimgdiff.py revision 623381880a32a2912f95949a6c406038ac4e7064
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import tempfile
import threading

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  """Return, as a string, a patch (bsdiff by default, imgdiff -z if
  'imgdiff' is set) that transforms the concatenation of the 'src'
  data pieces into the concatenation of the 'tgt' data pieces."""
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

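# Illustrative usage sketch; not part of the original module.
def _example_compute_patch():
  # Diff two small payloads; the strings here are hypothetical.  Assumes
  # the bsdiff binary is on the PATH, as compute_patch requires.
  patch = compute_patch(["hello, "], ["hello, world"])
  print("patch is %d bytes" % (len(patch),))
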
class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      # Each block is recorded as a (start, end) pair; RangeSet merges
      # the pairs of consecutive blocks into single ranges.
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1

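# Illustrative usage sketch; not part of the original module.
def _example_data_image():
  # Wrap a short (hypothetical) string, padding it out to one whole
  # 4096-byte block, and inspect the derived block bookkeeping.
  img = DataImage("hello", pad=True)
  assert img.total_blocks == 1
  assert img.file_map["__NONZERO"].size() == 1
  print("image sha1:", img.TotalSha1())
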

class Transfer(object):
  """One step of the update: produce the blocks in 'tgt_ranges' of the
  target image, optionally reading the blocks in 'src_ranges' of the
  source image (how is determined by 'style')."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

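# A minimal sketch of an image object satisfying the interface above;
# illustrative only, not part of the original module.  It wraps a list
# of equal-sized block strings, and the class name is hypothetical.
class ListImage(object):
  blocksize = 4096

  def __init__(self, blocks):
    assert all(len(b) == self.blocksize for b in blocks)
    self.blocks = blocks
    self.total_blocks = len(blocks)
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.file_map = {"__NONZERO": self.care_map}

  def ReadRangeSet(self, ranges):
    return [b for (s, e) in ranges for b in self.blocks[s:e]]

  def TotalSha1(self):
    h = sha1()
    for piece in self.ReadRangeSet(self.care_map):
      h.update(piece)
    return h.hexdigest()
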
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=2):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
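    #
    # For example (illustrative), a target file lib/libfoo-1.2.so is
    # matched first against a source lib/libfoo-1.2.so (rule 1), then
    # against any source whose basename is libfoo-1.2.so (rule 2), then
    # against a source whose basename reduces to the pattern
    # libfoo-#.#.so, such as lib/libfoo-1.3.so (rule 3).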
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    out = []

    out.append("%d\n" % (self.version,))   # format version number
    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version == 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
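        #
        # For example (illustrative values): "10 2,30,40" reads ten
        # blocks straight from the source; "10 2,30,35 2,0,5 1:2,5,10"
        # reads five source blocks plus five blocks from stash 1; and
        # "10 - 1:2,0,10" reads all ten blocks from stash 1.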

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          # Map the stashed range into the coordinate space of
          # src_ranges compacted to [0, size); this tells the updater
          # where the stashed piece lands in the assembled source data.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # both versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
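      #
      # For example (illustrative), version 2 might emit
      #   move 2,545,546 1 2,544,545
      # to rewrite block 545 with the contents of source block 544.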

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
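
      # For example (illustrative values), a version 2 transfer list
      # could therefore begin:
      #   2      <- format version
      #   5120   <- total blocks written
      #   3      <- stash slots needed
      #   40     <- maximum blocks stashed at once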

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
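            #
            # For example (illustrative): an intact foo.apk target whose
            # source blocks survive reordering untouched can use imgdiff;
            # one whose source blocks were trimmed by RemoveBackwardEdges
            # cannot, since 'intact' is cleared there.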
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort (Kahn's algorithm) to generate a new
    # vertex order, using a greedy heuristic to choose which vertex goes
    # next whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
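    #
    # For example (illustrative): with two transfers where edge A -> B
    # has weight 5 and edge B -> A has weight 1, placing A before B
    # leaves only the weight-1 edge pointing left, losing (or stashing)
    # at most one source block instead of five.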

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks saved if we pretend
      # it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks that are not
        # contained in any file and are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
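    # For example, RangeSets "0-9" and "10-19" partition "0-19", but
    # "0-10" and "10-19" do not: block 10 would be covered twice.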
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total

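
# Illustrative driver sketch; not part of the original module.  The data
# strings and the output prefix are hypothetical, and running this
# requires the bsdiff binary on the PATH (see compute_patch above).
def _example_block_image_diff():
  src = DataImage("A" * 4096 + "\0" * 4096)
  tgt = DataImage("B" * 4096 + "\0" * 4096)
  # Writes /tmp/example.transfer.list, /tmp/example.new.dat, and
  # /tmp/example.patch.dat describing how to turn src into tgt.
  BlockImageDiff(tgt, src, threads=1, version=2).Compute("/tmp/example")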