blockimgdiff.py revision e985f6f4d84e12ac3e189156f1260e008ef990d0
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass
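
# Illustrative use of compute_patch() (hypothetical inputs; requires the
# bsdiff/imgdiff tools on PATH):
#
#   patch = compute_patch(["hello, ", "world"], ["hello there, world"])
#
# 'src' and 'tgt' are iterables of strings; each is concatenated into a
# temp file, and the resulting binary patch is returned as a string.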

class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Classify each block as zero or nonzero.  The lists hold alternating
    # (start, end) block numbers; RangeSet coalesces the adjacent pairs
    # produced by consecutive blocks of the same kind into longer runs.
    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
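
  # Illustrative (hypothetical data): DataImage("\0" * 4096 + "x" * 4096)
  # has total_blocks == 2 and a file_map mapping "__ZERO" to the RangeSet
  # containing block 0 and "__NONZERO" to the RangeSet containing block 1.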


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range
#      [0, total_blocks)) we actually care about; i.e., which blocks
#      contain data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (i.e., all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
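#
# A minimal conforming image object might look like this (hypothetical
# sketch; the name is illustrative, not part of this module):
#
#   class OneBlockImage(object):
#     blocksize = 4096
#     total_blocks = 1
#     care_map = RangeSet(data=(0, 1))
#     file_map = {"__NONZERO": RangeSet(data=(0, 1))}
#     def ReadRangeSet(self, ranges):
#       return ["\xff" * 4096]
#     def TotalSha1(self):
#       return sha1("\xff" * 4096).hexdigest()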

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=2):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match, if available; otherwise
    #   2) an exact basename match, if available; otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#", if available; otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
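
  # Illustrative use (hypothetical image objects tgt_img and src_img):
  #
  #   BlockImageDiff(tgt_img, src_img, version=2).Compute("out/system")
  #
  # writes out/system.transfer.list, out/system.new.dat, and
  # out/system.patch.dat.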

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version == 2:

        # A version 2 source string takes one of the following forms:
        #
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # both versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
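      #
      # For example (illustrative values), a version 2 move of source
      # blocks 30-39 into target blocks 100-109 would be emitted as
      #
      #   move 2,100,110 10 2,30,40
      #
      # where "10" is the source block count and "2,30,40" is the raw
      # form of the source RangeSet.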

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (i.e., this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
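      # An illustrative (made-up) version 2 header, one value per line:
      #   2       <- format version
      #   524288  <- total blocks written
      #   3       <- stash slots needed
      #   2048    <- maximum blocks stashed at any one time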

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently,
          # e.g.:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (e.g., APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (i.e., the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
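
    # Note: prefix+".patch.dat" is just the concatenation of all the
    # patches; the (patch_start, patch_len) recorded on each transfer
    # above is what locates a patch within that blob when WriteTransfers()
    # emits the bsdiff/imgdiff commands.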

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
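    #
    # Illustrative tie-break (hypothetical transfers): if T1 (with
    # NetStashChange() == +8) and T2 (with NetStashChange() == -3) are
    # both ready, the min-heap used below pops T2 first, shrinking the
    # stash before it grows again.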

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight, which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
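    #
    # For instance (hypothetical weights): with edges A->B of weight 5
    # and B->A of weight 1, ordering A before B makes only the weight-1
    # edge point left, so that is the order the heuristic prefers.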

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved that we get
      # by pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
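        # For example (hypothetical ranges): if a writes blocks 10-19 and
        # b reads blocks 15-17, the edge gets weight 3 (the size of the
        # overlap); if b's source is the __ZERO domain, the weight is 0.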
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (i.e., they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
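
    # Illustrative (hypothetical values): [RangeSet("0-4"), RangeSet("5-9")]
    # is a partition of RangeSet("0-9"), while [RangeSet("0-5"),
    # RangeSet("5-9")] is not, since both pieces contain block 5.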