blockimgdiff.py revision 3203bc10f0bc6a1acce8afeee6d8df1f0f64f72c
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
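  """Writes src and tgt (iterables of strings) to temp files, runs bsdiff
  (or imgdiff with -z when imgdiff=True) on them, and returns the resulting
  patch contents."""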
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
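  """Abstract base for the image objects BlockImageDiff works on; see the
  comment above BlockImageDiff for the full set of attributes and methods an
  image must provide."""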
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Otherwise an unchanged last block could be skipped by
    # the incremental update, and would then fail the post-install
    # verification if its padding bytes hold non-zero contents.
    # Bug: 23828506
    if padded:
      self.clobbered_blocks = RangeSet(
          data=(self.total_blocks-1, self.total_blocks))
    else:
      self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

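    # Classify each block (excluding a padded last block, which goes into
    # clobbered_blocks above) as all-zero or not; the appended (i, i+1) pairs
    # become the __ZERO and __NONZERO domains below.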
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

    if self.clobbered_blocks:
      self.file_map["__COPY"] = self.clobbered_blocks

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # ReadRangeSet() returns a list of strings; join them before hashing.
      return sha1("".join(self.ReadRangeSet(ranges))).hexdigest()
    else:
      return sha1(self.data).hexdigest()


class Transfer(object):
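  """A single transfer: how to produce tgt_ranges in the target image, given
  src_ranges from the source.  The style starts as "zero", "new" or "diff";
  "diff" transfers later become "move", "bsdiff" or "imgdiff" in
  ComputePatches()."""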
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
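    """Returns the net change in stashed blocks if this transfer runs next:
    blocks it stashes for later transfers minus blocks it reads back out of
    the stash."""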
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range
#      [0, total_blocks)) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
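    """Computes the transfer plan and writes prefix+".transfer.list",
    prefix+".new.dat" and prefix+".patch.dat" (via ComputePatches and
    WriteTransfers)."""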
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
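    """Returns the hex SHA-1 of the data 'source' holds in 'ranges'."""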
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
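    """Writes the transfer list to prefix+".transfer.list", emitting stash
    and free commands as needed and tracking the peak number of stashed
    blocks."""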
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and above: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
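    """Turns each "diff" transfer into a "move" (identical data) or a
    "bsdiff"/"imgdiff" patch, writing prefix+".new.dat" and
    prefix+".patch.dat" and recording each patch's offset and length on its
    transfer."""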
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
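    """Fixes up the sequence for version 1 (no stashing): whenever a transfer
    is ordered after one that overwrites part of its source, drop the
    overwritten blocks from its source ranges (turning it into "new" if
    nothing is left to diff against)."""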
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
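    """Fixes up the sequence for versions 2 and above: whenever a transfer is
    ordered after one that overwrites part of its source, have the writer
    stash the overlapping blocks and reverse the edge so the reader runs
    after the writer, reading from the stash."""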
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved that we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
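    """Adds an edge b -> a (b.goes_before[a] / a.goes_after[b]) whenever
    transfer a writes blocks that transfer b reads; the edge weight is the
    number of overlapping blocks (counted as zero for the __ZERO domain)."""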
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
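    """Creates one Transfer per target domain, pairing it with a source
    domain by exact path, exact basename, or digit-collapsed basename, and
    falling back to "new" (send the full data) when nothing matches."""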
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
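    """Indexes source filenames by basename and by basename with runs of
    digits replaced by "#", for the fallback matching in FindTransfers."""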
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total