# blockimgdiff.py revision 937847ae493d36b9741f6387a2357d5cdceda3d9
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
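#
# A minimal usage sketch (hypothetical values; assumes common.OPTIONS has
# been populated, e.g. with cache_size and stash_threshold):
#
#   tgt = DataImage("\0" * 4096 + "x" * 4096)
#   b = BlockImageDiff(tgt, src=None, version=3)
#   b.Compute("out")  # writes out.transfer.list, out.new.dat, out.patch.dat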

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
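    # For example, rule 3 would let a hypothetical target file
    # "lib/libfoo10.so" be diffed against source "lib/libfoo9.so", since
    # both basenames normalize to "libfoo#.so".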
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
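        #
        # For example (version 3, hypothetical values; a raw rangeset
        # prints as "<token count>,<start>,<end>,..."):
        #   "8 2,0,8"                (all 8 blocks read from the source)
        #   "8 - <stash hash>:2,0,8" (all 8 blocks taken from one stash)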

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
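      #
      # For example, a hypothetical version 3 move of two blocks might be
      # emitted as:
      #   move <tgthash> 2,4,6 2 2,10,12
      # i.e. read source blocks [10,12) and write them to target blocks
      # [4,6), expecting the result to hash to <tgthash>.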

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). The threshold serves two
        # purposes: a) part of the cache may already be occupied by recovery
        # logs; b) it buys us some headroom to deal with oversize issues.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and higher: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold
      print("max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n" % (
          max_stashed_blocks, max_stashed_size, max_allowed,
          max_stashed_size * 100.0 / max_allowed))

  def ReviseStashSize(self):
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands, computing the required stash size on
    # the fly. If a command requires more stash space than is available, we
    # delete that stash by replacing the command that uses it with a "new"
    # command instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find the command
          # that will use this stash and replace it with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # The replaced command no longer uses the stashes in its "use_stash"
        # list. Remove the def points for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))
          new_blocks += sr.size()

        cmd.ConvertToNew()

    print("  Total %d blocks are packed as new blocks due to insufficient "
          "cache size." % (new_blocks,))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
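            #
            # For example, a hypothetical intact "app/Foo.apk" target would
            # qualify for imgdiff, while "lib/libbar.so" (not a zip) would
            # fall back to bsdiff.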
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
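    # (Concretely: this is Kahn's topological sort, with a min-heap picking
    # among the ready vertices the one whose NetStashChange() is smallest,
    # breaking ties by the previously-computed order.)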

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
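    #
    # For example, with three transfers and hypothetical edge weights
    # a->b (5), b->c (5), c->a (1), the heuristic picks the order
    # [a, b, c]: the only left-pointing edge is then c->a, so we lose
    # just 1 source block.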

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks we save by pretending
      # it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
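        # (E.g., with hypothetical ranges: if A writes blocks [0,5) and B
        # reads blocks [3,8), the intersection [3,5) gives the edge a
        # weight of 2, the number of source blocks B would lose if the
        # edge were removed.)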
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for the resumable update
      feature. However, with the growth of file sizes and the shrinking of
      the cache partition, source blocks may be too large to be stashed. If
      a file occupies too many blocks (greater than
      MAX_BLOCKS_PER_DIFF_TRANSFER), we split it into smaller pieces by
      generating multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff,
      but only bsdiff."""
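
      # For example, a hypothetical 2500-block target named "foo" (with a
      # similarly sized source) would become three transfers: "foo-0"
      # (1024 blocks), "foo-1" (1024 blocks) and "foo-2" (452 blocks).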

      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total

1054