# blockimgdiff.py, revision 29f529f33e5288efd8823cdc935260cfb0d8c7fa
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


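# A minimal usage sketch for compute_patch(), assuming the bsdiff/imgdiff
# tools are available on PATH (as the OTA build environment provides).  src
# and tgt may be any iterables of strings whose concatenation forms the full
# data, which is how ReadRangeSet() results get passed in below.
# (Hypothetical helper for illustration; nothing in this module calls it.)
def _example_compute_patch():
  src_pieces = ["hello ", "world\n"]          # e.g. chunks from ReadRangeSet()
  tgt_pieces = ["hello ", "there, world\n"]
  patch = compute_patch(src_pieces, tgt_pieces, imgdiff=False)
  print("patch is %d bytes" % (len(patch),))

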
class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    return sha1(self.data).hexdigest()

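
# An illustrative sketch of DataImage (hypothetical data; nothing runs this):
# pad a short string out to whole blocks and observe how the blocks get
# partitioned into the __NONZERO and __ZERO domains.
def _example_data_image():
  img = DataImage("A" * 5000 + "\0" * 4096, pad=True)
  assert img.total_blocks == 3                      # 9096 bytes pad to 3 blocks
  print(img.file_map["__NONZERO"].to_string_raw())  # blocks [0,2): the "A"s
  print(img.file_map["__ZERO"].to_string_raw())     # block [2,3): all zeros
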

class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range
#      [0, total_blocks)) we actually care about; i.e. which blocks
#      contain data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient; see the example below.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

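# (For example, given a RangeSet covering blocks [0,2) and [5,6),
# ReadRangeSet() may return ["<block0><block1>", "<block5>"] or
# ["<block0>", "<block1>", "<block5>"]; only the concatenation of the
# pieces matters.)
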
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        # The src_str takes one of the following forms:
        #
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
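        #
        # For example (version-2 syntax, hypothetical values),
        # "10 2,20,30" reads ten blocks straight out of source blocks
        # [20,30), while "10 - 0:2,0,10" reads all ten blocks back from
        # stash slot 0.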

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh,))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
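      #
      # For example (hypothetical numbers), a version-2 line might read
      # "bsdiff 0 4052 2,545,565 20 2,544,564": apply the patch found at
      # offset 0, length 4052, in the patch stream, writing target blocks
      # [545,565) from the 20 source blocks [544,564).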

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and up: after the total block count, we give the
      # number of stash slots needed and the maximum stash size needed
      # (in blocks).
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
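
    # So a version-2 transfer list might begin (hypothetical numbers):
    #   2        <- format version
    #   156790   <- total blocks written
    #   4        <- stash slots needed
    #   60       <- maximum blocks stashed at any one time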

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks we save by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
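        # (For example, if A writes blocks [10,20) and B reads [15,25),
        # then B must run first; the edge weight is 5, the size of the
        # overlap [15,20).)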
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
  def AbbreviateSourceNames(self):
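    # Index every source file both by its exact basename and by its
    # "number pattern" (all digit runs replaced with "#").  For example
    # (hypothetical name), "system/lib/libfoo21.so" is indexed under
    # "libfoo21.so" and "libfoo#.so", so a target "libfoo22.so" can
    # still find it as a diff source.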
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
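

# A typical driver sketch (hypothetical prefix; in practice the releasetools
# scripts construct the image objects and call this).  Compute() writes three
# artifacts for the updater: prefix.transfer.list, prefix.new.dat, and
# prefix.patch.dat.
def _example_block_image_diff(tgt_image, src_image):
  b = BlockImageDiff(tgt_image, src_image, version=2)
  b.Compute("/tmp/system")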