blockimgdiff.py revision dd67a295cc91ad636b81b705b5bedda945035568
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import tempfile
import threading

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(data=zero_blocks),
                     "__NONZERO": RangeSet(data=nonzero_blocks)}
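    # For example (illustrative): 16384 bytes of data whose second
    # 4096-byte block is all zeros gives total_blocks == 4 and
    # file_map == {"__ZERO": RangeSet("1"), "__NONZERO": RangeSet("0 2-3")}.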

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
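#
# Typical use looks like the following sketch (tgt_data and src_data
# are hypothetical strings of partition contents; Compute() writes its
# three output files next to the given prefix):
#
#   tgt = DataImage(tgt_data, pad=True)
#   src = DataImage(src_data, pad=True)
#   b = BlockImageDiff(tgt, src, version=2)
#   b.Compute("out/system")   # writes out/system.transfer.list,
#                             # out/system.new.dat, out/system.patch.dat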

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
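    # For example: a target file "lib/libfoo-1.3.so" with no exact path
    # or basename match in the source would still pair up with a source
    # file "lib/libfoo-1.2.so", since both basenames reduce to
    # "libfoo-#.so" under rule 3.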
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges):
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
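        #
        # For example, in version 2: "10 2,20,30" reads ten blocks
        # straight from the source; "10 2,20,25 2,0,5 9:2,5,10" reads
        # five blocks from the source plus five from stash slot 9; and
        # "10 - 9:2,0,10" reads everything from stashes.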

        size = xf.src_ranges.size()
        src_string = [str(size)]
        free_string = []

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_string.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh,))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   move hash <tgt rangeset> <src_string>
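      #
      # For example, moving blocks [20,30) to [40,50) would be emitted
      # as "move 2,20,30 2,40,50" in version 1 and, with no stashed
      # blocks involved, as "move 2,40,50 10 2,20,30" in version 2.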

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          elif self.version >= 3:
            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        elif self.version >= 3:
          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and later: after the total block count, we give the
      # number of stash slots needed and the maximum stash size needed
      # (in blocks).
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
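      # The transfer list therefore starts with four header lines, e.g.:
      #
      #   2      (format version number)
      #   150    (total blocks written)
      #   3      (stash slots needed)
      #   20     (maximum blocks stashed at any one time)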

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
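            #
            # For example: an APK whose source and target blocks are
            # both stored in increasing order starts out with
            # xf.intact == True and so qualifies here; but if
            # RemoveBackwardEdges had to trim blocks out of its source
            # set, intact was cleared and we fall back to bsdiff.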
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
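    # For example: if transfers A and B are both ready to be emitted,
    # and A frees ten stashed blocks (NetStashChange() == -10) while B
    # stashes five more (+5), the heap pops A first, keeping the peak
    # amount of stashed data down.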

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
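    #
    # For example: if three transfers form the cycle X -> Y -> Z -> X
    # with edge weights 5, 5, and 1, any linear order must leave one
    # edge pointing left; the heuristic tries to arrange things so the
    # leftover edge is the weight-1 one, losing one source block
    # instead of five.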

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks saved if we pretend
      # it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence, both in the 'order' field of each
    # transfer and by rearranging self.transfers into the chosen
    # sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
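        # For example: if A writes blocks [10,20) and B reads blocks
        # [15,17), the intersection i is [15,17), so B gets a
        # goes_before edge to A with weight 2 (the two overlapping
        # blocks).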
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total