blockimgdiff.py revision 5fcaaeffc3d132f1223adf7ee8230b6815477af5
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

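
# Illustrative usage sketch, not part of the original module and never
# called by it: compute_patch() takes the source and target data as
# iterables of strings (the same shape an image's ReadRangeSet() returns)
# and returns the raw patch bytes.  Assumes the bsdiff (or, with
# imgdiff=True, imgdiff) binary is available on PATH.
def _compute_patch_example():
  src_pieces = ["old ", "contents"]   # concatenates to the source data
  tgt_pieces = ["new contents"]
  return compute_patch(src_pieces, tgt_pieces, imgdiff=False)
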

class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(data=zero_blocks),
                     "__NONZERO": RangeSet(data=nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()

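
# Illustrative sketch, not part of the original module and never called by
# it: DataImage splits its data into 4k blocks and partitions the care map
# into the special "__ZERO" and "__NONZERO" domains.  Here block 0 is
# zero-filled and block 1 is not, so "__ZERO" maps to block 0 and
# "__NONZERO" to block 1.
def _data_image_example():
  img = DataImage("\0" * 4096 + "x" * 4096)
  assert img.total_blocks == 2
  # ReadRangeSet returns a list of strings; joined together they form the
  # requested data.
  return "".join(img.ReadRangeSet(img.file_map["__NONZERO"]))
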

class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

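
# A minimal sketch of an image object satisfying the interface described
# above (illustrative only, not part of the original module; EmptyImage and
# DataImage are the real reference implementations): one 4k block of data,
# no clobbered blocks.
class _ExampleImage(Image):
  blocksize = 4096
  total_blocks = 1
  care_map = RangeSet(data=(0, 1))
  clobbered_blocks = RangeSet()
  file_map = {"example.txt": RangeSet(data=(0, 1))}

  def __init__(self):
    self._data = "A" * self.blocksize

  def ReadRangeSet(self, ranges):
    return [self._data[s*self.blocksize:e*self.blocksize]
            for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    return sha1(self._data).hexdigest()

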
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and later: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

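  # Illustrative sketch, not part of the original module: for a version-2
  # full OTA of a hypothetical 10-block image where nothing is read from
  # the source, WriteTransfers would emit a transfer list along these
  # lines (format version, total blocks written, number of stash slots
  # used, maximum stashed blocks, then the commands):
  #
  #   2
  #   10
  #   0
  #   0
  #   erase 2,0,10
  #   new 2,0,10
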
  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

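  # Illustrative sketch, not part of the original module and never called
  # by it: two transfers where b reads the very blocks that a writes, so
  # GenerateDigraph would record that b goes before a, with an edge weight
  # of 2 (the size of the overlap in blocks).
  @staticmethod
  def _generate_digraph_example():
    by_id = []
    a = Transfer("a", "a", RangeSet(data=(0, 2)), RangeSet(data=(2, 4)),
                 "diff", by_id)
    b = Transfer("b", "b", RangeSet(data=(4, 6)), RangeSet(data=(0, 2)),
                 "diff", by_id)
    # a.tgt_ranges (blocks 0-1) intersects b.src_ranges (blocks 0-1); after
    # GenerateDigraph ran over [a, b], b.goes_before would be {a: 2} and
    # a.goes_after would be {b: 2}.
    return a, b
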
  def FindTransfers(self):
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

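  # Illustrative sketch, not part of the original module and never called
  # by it: rule (3) of the source-matching heuristic in Compute() replaces
  # runs of digits with "#", so a version-bumped filename still finds its
  # source.
  @staticmethod
  def _number_pattern_example():
    assert re.sub("[0-9]+", "#", "libfoo-1.2.so") == "libfoo-#.#.so"
    assert (re.sub("[0-9]+", "#", "libfoo-10.0.so") ==
            re.sub("[0-9]+", "#", "libfoo-1.2.so"))
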
  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total

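
# Illustrative end-to-end sketch, not part of the original module and never
# called by it: diff two in-memory images and write
# "out/system.transfer.list", "out/system.new.dat" and
# "out/system.patch.dat" (the "out/system" prefix is hypothetical).
# Requires the bsdiff/imgdiff binaries on PATH.
def _block_image_diff_example():
  # One nonzero block followed by one zero block in each image.
  src = DataImage("o" * 4096 + "\0" * 4096)
  tgt = DataImage("n" * 4096 + "\0" * 4096)
  BlockImageDiff(tgt, src, version=3).Compute("out/system")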