blockimgdiff.py revision 5ece99d64efe82fc1ca97a96f079ce69ea588a24
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
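  """Return the raw patch bytes that turn the data in 'src' into 'tgt'.

  'src' and 'tgt' are iterables of strings holding the block data.  The
  patch is produced by writing both to temporary files and invoking the
  external bsdiff (or, if 'imgdiff' is set, imgdiff -z) tool; the
  temporary files are removed before returning.
  """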
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

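    # A RangeSet built from a positional list takes flat [start, end)
    # pairs, so each block i below contributes the pair (i, i+1); runs
    # of adjacent blocks coalesce into single ranges.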
    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    # DataImage always carries empty clobbered_blocks.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
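  """One step of the update: produce a set of target blocks, optionally
  by reading a set of source blocks.  'style' is one of "zero", "new",
  "move", or "diff" (a "diff" is later refined to "bsdiff" or "imgdiff").
  """
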
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
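    """Return the net change in stashed blocks from running this transfer:
    blocks newly stashed before it runs minus stashed blocks it consumes."""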
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
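#
# A minimal driver (an illustrative sketch; the image variables are
# hypothetical, everything else is this module's public API):
#
#   from blockimgdiff import BlockImageDiff
#
#   b = BlockImageDiff(tgt_image, src_image, version=3)
#   b.Compute("out/system")   # writes out/system.transfer.list,
#                             # out/system.new.dat, out/system.patch.dat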

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
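    """Compute the list of transfers and write the output files.

    Writes prefix + ".transfer.list" (the command script for the
    updater), prefix + ".new.dat" (raw data for "new" commands), and
    prefix + ".patch.dat" (the concatenated bsdiff/imgdiff patches).
    """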
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
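    """Return the hex SHA-1 of the data contained in 'ranges' of 'source'."""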
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
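      #
      # For example (an illustrative sketch; to_string_raw() renders a
      # RangeSet as "<token count>,<start>,<end>,..."), a version 3
      # move of source blocks 10-19 onto target blocks 0-9, with no
      # stash involved, would be emitted as:
      #
      #   move <sha1 of the target blocks' data> 2,0,10 10 2,10,20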

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
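    #
    # For example, a vertex whose outgoing edges weigh {4, 2} and whose
    # single incoming edge weighs 1 has d = (4+2) - 1 = 5; picking the
    # vertex with the largest d first tends to orient the heaviest
    # edges left-to-right, so the edges we eventually break are cheap.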

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks we save by treating
      # it as a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
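    """Match each target domain to a source domain and emit Transfers.

    Matching follows the order described in Compute(): exact pathname,
    then exact basename, then basename with runs of digits replaced by
    "#"; targets with no match become "new" transfers.
    """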
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total