# blockimgdiff.py, revision 7b985f6aedad882c7b3332032bd9265ec42fc871
#
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import heapq
20import itertools
21import multiprocessing
22import os
23import pprint
24import re
25import subprocess
26import sys
27import threading
28import tempfile
29
30from rangelib import *
31
# Names exported by "from blockimgdiff import *".
__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
33
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms src into tgt.

  Args:
    src: an iterable of strings; concatenated, they form the source data.
    tgt: an iterable of strings; concatenated, they form the target data.
    imgdiff: if True, run "imgdiff -z" (suitable for zip-format data);
        otherwise run "bsdiff".

  Returns:
    The patch contents, as a string.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  # The patch file is (re)created by the external tool; we only need its name.
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Discard imgdiff's output; close the sink afterwards so we don't
      # leak a file descriptor per call (the original left it open).
      devnull = open(os.devnull, "a")
      try:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
      finally:
        devnull.close()
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each temp file independently: a failure removing one must
    # not prevent removal of the others.
    for fn in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(fn)
      except OSError:
        pass
71
class EmptyImage(object):
  """An image containing no blocks at all.

  Used as the source image when generating a full (non-incremental)
  update, so that no transfer ever reads from the source.
  """
  blocksize = 4096
  total_blocks = 0
  care_map = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    # There is no data, so every range reads back empty.
    return ()

  def TotalSha1(self):
    # SHA-1 of zero bytes of data.
    return sha1(b"").hexdigest()
82
83
class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' as a block image.

    Args:
      data: the raw image contents, as a string.
      trim: if True, discard a trailing partial block.
      pad: if True, zero-fill a trailing partial block to full size.
          (trim and pad are mutually exclusive.)

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
          neither trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Use floor division so total_blocks is always an int; plain "/"
    # yields a float under Python 3 (or "from __future__ import
    # division"), which would break range() and the slicing below.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Classify every block as all-zero or not.  Each block contributes a
    # (start, end) = (i, i+1) pair to the list consumed by RangeSet().
    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Return the data for 'ranges', one string per (start, end) pair."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    """Return the hex SHA-1 of the whole image (cached after first call)."""
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
132
133
class Transfer(object):
  """One update operation: write tgt_ranges, possibly reading src_ranges.

  Each instance registers itself in the 'by_id' list given to the
  constructor; its index in that list becomes its id.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style

    # "Intact" means both range sets keep their blocks in increasing
    # order, which is a precondition for using imgdiff.
    src_mono = getattr(src_ranges, "monotonic", False)
    tgt_mono = getattr(tgt_ranges, "monotonic", False)
    self.intact = tgt_mono and src_mono

    # Ordering-dependency edges to other transfers.
    self.goes_before = {}
    self.goes_after = {}

    # (id, RangeSet) pairs stashed before / consumed by this transfer
    # (only used for version >= 2 transfer lists).
    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Return the net number of blocks this transfer adds to the stash."""
    added = sum(sr.size() for (_, sr) in self.stash_before)
    freed = sum(sr.size() for (_, sr) in self.use_stash)
    return added - freed

  def __str__(self):
    return "%d: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
159
160
161# BlockImageDiff works on two image objects.  An image object is
162# anything that provides the following attributes:
163#
164#    blocksize: the size in bytes of a block, currently must be 4096.
165#
166#    total_blocks: the total size of the partition/image, in blocks.
167#
168#    care_map: a RangeSet containing which blocks (in the range [0,
169#      total_blocks) we actually care about; i.e. which blocks contain
170#      data.
171#
172#    file_map: a dict that partitions the blocks contained in care_map
173#      into smaller domains that are useful for doing diffs on.
174#      (Typically a domain is a file, and the key in file_map is the
175#      pathname.)
176#
177#    ReadRangeSet(): a function that takes a RangeSet and returns the
178#      data contained in the image blocks of that RangeSet.  The data
179#      is returned as a list or tuple of strings; concatenating the
180#      elements together should produce the requested data.
181#      Implementations are free to break up the data into list/tuple
182#      elements in any way that is convenient.
183#
184#    TotalSha1(): a function that returns (as a hex string) the SHA-1
185#      hash of all the data in the image (ie, all the blocks in the
186#      care_map)
187#
188# When creating a BlockImageDiff, the src image may be None, in which
189# case the list of transfers produced will never read from the
190# original image.
191
class BlockImageDiff(object):
  """Computes a block-level diff between two images (see module comment
  above for the image-object interface) and writes the transfer list,
  new-data and patch-data files consumed by the updater."""

  def __init__(self, tgt, src=None, threads=None, version=3):
    """Set up the diff.

    Args:
      tgt: the target image object.
      src: the source image object, or None to build a full update.
      threads: worker threads for patch computation; defaults to half
          the CPU count (minimum 1).
      version: transfer list format version (1, 2 or 3).
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write "<prefix>.transfer.list",
    "<prefix>.new.dat" and "<prefix>.patch.dat"."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges):
    """Return the hex SHA-1 of the data in 'ranges' of image 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit "<prefix>.transfer.list" from self.transfers, tracking stash
    usage as the transfers would execute on the device."""
    out = []

    total = 0
    performs_read = False

    # Maps stash id (v2) or stash hash (v3) to bookkeeping state.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # Stash slot ids are reused once freed (v2 numbering).
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        # Version 1 has no stash support at all.
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 addresses stashes by content hash and refcounts
          # duplicates, emitting the stash command only once per hash.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stash's destination relative to the source span.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_string.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              # Last reference: free the stash after this command.
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   move hash <tgt rangeset> <src_string>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          elif self.version >= 3:
            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        elif self.version >= 3:
          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # NOTE: was the Python-2-only statement form "raise ValueError, ...";
        # the call form below is equivalent and valid in both Python 2 and 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Resolve every "diff" transfer into "move"/"bsdiff"/"imgdiff",
    computing patches in parallel, and write "<prefix>.new.dat" and
    "<prefix>.patch.dat"."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so the largest targets are popped (and diffed) first.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate the patches, recording each one's offset and length.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert that the transfer sequence is executable as ordered."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Stashed blocks were read (and saved) before being overwritten,
        # so they don't count as reads here.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Reorder the (now acyclic) transfer graph to reduce peak stash use."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version 1 fixup: drop source blocks that would be read after
    being overwritten, so the chosen order becomes valid."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version 2+ fixup: instead of dropping source blocks, stash them
    and reverse the violated edges."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Heuristically order the transfers to minimize the weight of
    backward (left-pointing) dependency edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per target file, pairing it with the best
    available source (exact path, basename, or number-pattern match)."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source filenames by basename and by number-pattern."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
866