blockimgdiff.py revision d47d8e14880132c42a75f41c8041851797c75e35
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  """Compute a patch (with bsdiff, or imgdiff -z if requested) that
  transforms the concatenation of the 'src' strings into the
  concatenation of the 'tgt' strings; returns the patch contents."""
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

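# Illustrative usage sketch, not part of the original module: it assumes
# the bsdiff tool is on PATH, since compute_patch() shells out to it.
# The src/tgt arguments are any iterables of strings whose concatenation
# gives the full file contents.
def _example_compute_patch():
  patch = compute_patch(["hello ", "world\n"], ["hello ", "there\n"])
  print("patch is %d bytes" % (len(patch),))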

class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}

  def ReadRangeSet(self, ranges):
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()

class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()

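# Illustrative usage sketch, not part of the original module: wrap raw
# data in a DataImage and inspect the derived block bookkeeping.  pad=True
# zero-fills the final partial block.
def _example_data_image():
  img = DataImage("A" * 5000, pad=True)  # padded up to 2 full blocks
  assert img.total_blocks == 2
  assert img.file_map["__NONZERO"].size() == 2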

class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS.  They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
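#
# For example (illustrative): an 8-block image whose only file occupies
# blocks 0-3 and whose remaining blocks are zero-filled could provide
#
#    care_map = RangeSet("0-7")
#    file_map = {"/some/file": RangeSet("0-3"), "__ZERO": RangeSet("4-7")}
#
# ie, the RangeSets in file_map form a partition of the care_map.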

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
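    #
    # For example (illustrative), a target file "libfoo26.so" with no
    # path or basename match can still be diffed against a source
    # "libfoo25.so" under rule 3, since both normalize to "libfoo#.so".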
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        # Construct the <src_str> argument, which takes one of these
        # forms:
        #
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
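        #
        #   e.g. (illustrative) "4 2,10,14" reads four blocks directly
        #   from source blocks 10-13, while "4 - <hash>:2,0,4" reads all
        #   four from a stash ("<hash>" stands in for a real SHA-1).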

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh,))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
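      #
      # e.g. (illustrative) a version-2 move of one block from source
      # block 100 to target block 545 would be emitted as
      #   move 2,545,546 1 2,100,101
      # using the raw RangeSet notation (value count, then half-open
      # range endpoints).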

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold).  There are two reasons
        # for having a threshold here: a) part of the cache may have been
        # occupied by some recovery logs, and b) it buys us some time to
        # deal with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)


    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and later: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks).
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold
      print("max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n" % (
          max_stashed_blocks, max_stashed_size, max_allowed,
          max_stashed_size * 100.0 / max_allowed))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
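            #
            # For example (illustrative): an intact "app/Foo.apk" target
            # qualifies for imgdiff here, while "lib/libfoo.so" always
            # falls back to bsdiff.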
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # If this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
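    #
    # For example (illustrative): if A must go before B with weight 5
    # and B must go before A with weight 1 (a two-vertex cycle), placing
    # A before B leaves only the weight-1 edge pointing left, so at most
    # one source block is lost when the backward edge is removed (v1) or
    # reversed via a stash (v2+).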

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # the "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
903        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
904                 "diff", self.transfers)
905        continue
906
907      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
908
909  def AbbreviateSourceNames(self):
910    for k in self.src.file_map.keys():
911      b = os.path.basename(k)
912      self.src_basenames[b] = k
913      b = re.sub("[0-9]+", "#", b)
914      self.src_numpatterns[b] = k
915
  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
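

# Illustrative end-to-end sketch, not part of the original module: build
# a block-based diff between two small in-memory images.  It assumes that
# common.OPTIONS.cache_size and common.OPTIONS.stash_threshold have been
# set (the OTA generation scripts normally do this) and that bsdiff is on
# PATH.  Compute() then writes out.transfer.list, out.new.dat and
# out.patch.dat.
def _example_block_image_diff():
  common.OPTIONS.cache_size = 256 * 4096
  common.OPTIONS.stash_threshold = 0.8
  src = DataImage("a" * 4096 * 2 + "\0" * 4096 * 2)
  tgt = DataImage("b" * 4096 * 3 + "\0" * 4096)
  BlockImageDiff(tgt, src, threads=1).Compute("out")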