blockimgdiff.py revision 575d68a48edc90d655509f2980dacc69958948de
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
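  # Contract, for reference: 'src' and 'tgt' are iterables of byte strings
  # (e.g. the lists returned by Image.ReadRangeSet()); the return value is
  # the raw patch produced by bsdiff, or by imgdiff -z when imgdiff=True.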
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
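#
# As an illustration, a minimal conforming image object might look like
# the sketch below (the class name is hypothetical, not part of this
# module; note that BlockImageDiff also expects an 'extended' RangeSet
# attribute, as defined by EmptyImage and DataImage above):
#
#   class TinyImage(object):
#     blocksize = 4096
#     total_blocks = 2
#     care_map = RangeSet(data=(0, 2))
#     clobbered_blocks = RangeSet()
#     extended = RangeSet()
#     file_map = {"__NONZERO": RangeSet(data=(0, 2))}
#     def ReadRangeSet(self, ranges): ...
#     def TotalSha1(self, include_clobbered_blocks=False): ...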

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
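    #
    # For example (hypothetical filenames), a target file
    # "lib/libfoo-1.2.so" can be diffed against a source file
    # "lib/libfoo-1.1.so" via rule 3, since both basenames reduce to
    # "libfoo-#.#.so".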
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0
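    # Note on 'stashes': it maps the stash id assigned in
    # ReverseBackwardEdges() to a stash slot number (version 2); for
    # version 3 it also maps a block hash to the number of commands
    # still referring to that stashed data.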

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
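        #
        # e.g. (made-up values) "2 2,10,12" reads two blocks straight
        # from the source image, while "2 - <hash>:2,0,2" fills both
        # blocks from a stash.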

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
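      #
      # As a concrete (made-up) version 3 example, moving source blocks
      # [100,102) to target blocks [500,502) would be emitted as
      #   move <tgthash> 2,500,502 2 2,100,102
      # where "2,500,502" is the target rangeset, the first "2" is the
      # source block count, and "2,100,102" is the source rangeset.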

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). The threshold serves
        # two purposes: a) part of the cache may already be occupied by
        # recovery logs, and b) it leaves us some margin for dealing with
        # the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # For versions 2 and later: after the total block count, we give
      # the number of stash slots needed, and the maximum size needed
      # (in blocks).
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
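      # So a (made-up) version 2 transfer list might begin:
      #   2      <- format version
      #   104    <- total blocks written
      #   3      <- stash slots needed
      #   20     <- maximum blocks stashed at any one time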

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, max_stashed_size, max_allowed,
              max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, max_stashed_size))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
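    # (Concretely, this is Kahn's topological sort with the ready set
    # kept in a min-heap keyed on (NetStashChange(), order), so the
    # vertex that frees the most stash space is emitted first and ties
    # break deterministically.)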

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
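    #
    # Sketch of the loop below: repeatedly peel sinks off G (appended,
    # right to left, onto the tail s2) and sources off G (appended, left
    # to right, onto the head s1); when G has neither, greedily promote
    # the vertex with the largest sum(outgoing) - sum(incoming) weight
    # difference to s1 and repeat.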

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
932