blockimgdiff.py revision 007979ee7543a396d97b3e9ada21aca44d503597
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import threading
import tempfile

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
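  """Compute the patch that transforms src into tgt, where src and tgt
  are lists/tuples of strings as returned by ReadRangeSet().  Runs the
  external bsdiff tool (or imgdiff -z, if imgdiff=True) and returns the
  patch contents; raises ValueError if the tool fails."""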
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    # DataImage always carries empty clobbered_blocks.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
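  """A single transfer: a command to produce the tgt_ranges blocks of
  the target image, either from new data or from the src_ranges blocks
  of the source image.  'style' names how ("new", "zero", "move",
  "diff", "bsdiff", or "imgdiff"); 'by_id' is the list of all transfers,
  to which the constructor appends the new object."""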
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
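#
# As an illustrative sketch only (EmptyImage and DataImage above are the
# real implementations), a hypothetical minimal image object could look
# like this:
#
#   class FourZeroBlocks(object):
#     blocksize = 4096
#     total_blocks = 4
#     care_map = RangeSet(data=(0, 4))
#     clobbered_blocks = RangeSet()
#     file_map = {"__ZERO": RangeSet(data=(0, 4))}
#     def ReadRangeSet(self, ranges):
#       return ['\0' * ((e - s) * self.blocksize) for (s, e) in ranges]
#     def TotalSha1(self):
#       return sha1('\0' * (4 * 4096)).hexdigest()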

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=2):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
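    # (For example, with hypothetical names, a target file
    # "lib/libfoo12.so" could be diffed against a source file
    # "lib/libfoo11.so", since both basenames reduce to "libfoo#.so".)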
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version == 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
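        #
        # For instance (hypothetical values): "10 2,0,10" reads 10
        # blocks from source range [0,10), and "4 - 0:2,0,4" reads all 4
        # blocks from stash slot 0 (RangeSets appear in the raw
        # count-prefixed form produced by to_string_raw()).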

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # both versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
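      #
      # As a sketch with hypothetical values, a version-2 transfer list
      # might contain lines like:
      #
      #   move 2,40,50 10 2,0,10
      #   bsdiff 0 3074 2,0,10 10 2,20,30
      #
      # i.e. copy source blocks [0,10) to target blocks [40,50), then
      # patch target blocks [0,10) using source blocks [20,30).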

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space ((512 << 20) bytes / 4096 bytes per block = 131072
      # blocks).
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
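    # (This is Kahn's topological sort, with the ready set kept in a
    # min-heap keyed by NetStashChange so that, among the available
    # vertices, we always emit the one that frees the most stash.)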

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
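    #
    # (Toy example: if two transfers each read blocks the other writes,
    # the graph has edges A->B and B->A, and any linear order must
    # violate one of them.  If A->B carries weight 5 and B->A carries
    # weight 1, placing B before A sacrifices only 1 source block.)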

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved that we get
      # by pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # the "__COPY" domain includes all the blocks not contained in
        # any file and that need to be copied unconditionally to the
        # target.
        print("FindTransfers: new", tgt_ranges)
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total