blockimgdiff.py revision 28f6f9c3deb4d23fd627d15631d682a5cfa989fc
1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import common
20import heapq
21import itertools
22import multiprocessing
23import os
24import re
25import subprocess
26import threading
27import tempfile
28
29from rangelib import RangeSet
30
31
32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
33
34
def compute_patch(src, tgt, imgdiff=False):
  """Return a patch (as a byte string) that transforms src into tgt.

  Args:
    src: iterable of strings comprising the source data.
    tgt: iterable of strings comprising the target data.
    imgdiff: if True, run "imgdiff -z" (better for zip-format files);
        otherwise run "bsdiff".

  Raises:
    ValueError: if the external diff tool exits with a non-zero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tool wants to create the patch file itself; remove the
    # placeholder mkstemp() made for us.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Open the null device in a with-block so the handle is closed
      # deterministically (the previous code leaked it), and use
      # os.devnull for portability instead of the literal "/dev/null".
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each temp file independently so one failure doesn't leak
    # the remaining files (the old single try block did).
    for name in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(name)
      except OSError:
        pass
72
73
class Image(object):
  """Abstract interface that BlockImageDiff requires of image objects."""

  def ReadRangeSet(self, ranges):
    """Return the data in the given blocks as a list/tuple of strings."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the image's data."""
    raise NotImplementedError
80
81
class EmptyImage(Image):
  """A zero-length image."""
  # Class-level attributes satisfy the image interface: no blocks at all.
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    # There is no data to read, whatever ranges were requested.
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    # SHA-1 of zero bytes of input.
    return sha1().hexdigest()
97
98
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' as an image.

    Args:
      data: the raw partition contents as a single string.
      trim: if True, drop a trailing partial block instead of failing.
      pad: if True, zero-fill a trailing partial block instead of failing.

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
          neither trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int under Python 3 (or
    # "from __future__ import division") as well as Python 2.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      self.clobbered_blocks = RangeSet(
          data=(self.total_blocks-1, self.total_blocks))
    else:
      self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify each (non-padded) block as all-zero or not; the lists hold
    # flattened (start, end) pairs in the form RangeSet accepts.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

    if self.clobbered_blocks:
      self.file_map["__COPY"] = self.clobbered_blocks

  def ReadRangeSet(self, ranges):
    """Return the data for the given ranges, one string per range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the image data.

    By default clobbered blocks are excluded, matching what post-install
    verification checks; pass include_clobbered_blocks=True to hash the
    entire data string.
    """
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # ReadRangeSet() returns a *list* of strings; sha1() only accepts a
      # single buffer, so feed the pieces in incrementally. (The previous
      # code passed the list directly, raising TypeError at runtime.)
      ctx = sha1()
      for piece in self.ReadRangeSet(ranges):
        ctx.update(piece)
      return ctx.hexdigest()
    else:
      return sha1(self.data).hexdigest()
165
166
class Transfer(object):
  """One update step: produce tgt_ranges from src_ranges via 'style'."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # "Intact" means both range sets store their blocks in increasing
    # order, which is a precondition for using imgdiff.
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # OrderedDict rather than dict keeps the output repeatable; a plain
    # dict would make it depend on the hashes of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Blocks this command adds to the stash, minus blocks it frees."""
    added = sum(sr.size() for (_, sr) in self.stash_before)
    freed = sum(sr.size() for (_, sr) in self.use_stash)
    return added - freed

  def ConvertToNew(self):
    """Turn this command into a plain "new", dropping any source usage."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
201
202
203# BlockImageDiff works on two image objects.  An image object is
204# anything that provides the following attributes:
205#
206#    blocksize: the size in bytes of a block, currently must be 4096.
207#
208#    total_blocks: the total size of the partition/image, in blocks.
209#
210#    care_map: a RangeSet containing which blocks (in the range [0,
211#      total_blocks) we actually care about; i.e. which blocks contain
212#      data.
213#
214#    file_map: a dict that partitions the blocks contained in care_map
215#      into smaller domains that are useful for doing diffs on.
216#      (Typically a domain is a file, and the key in file_map is the
217#      pathname.)
218#
219#    clobbered_blocks: a RangeSet containing which blocks contain data
220#      but may be altered by the FS. They need to be excluded when
221#      verifying the partition integrity.
222#
223#    ReadRangeSet(): a function that takes a RangeSet and returns the
224#      data contained in the image blocks of that RangeSet.  The data
225#      is returned as a list or tuple of strings; concatenating the
226#      elements together should produce the requested data.
227#      Implementations are free to break up the data into list/tuple
228#      elements in any way that is convenient.
229#
230#    TotalSha1(): a function that returns (as a hex string) the SHA-1
231#      hash of all the data in the image (ie, all the blocks in the
232#      care_map minus clobbered_blocks, or including the clobbered
233#      blocks if include_clobbered_blocks is True).
234#
235# When creating a BlockImageDiff, the src image may be None, in which
236# case the list of transfers produced will never read from the
237# original image.
238
239class BlockImageDiff(object):
240  def __init__(self, tgt, src=None, threads=None, version=3):
241    if threads is None:
242      threads = multiprocessing.cpu_count() // 2
243      if threads == 0:
244        threads = 1
245    self.threads = threads
246    self.version = version
247    self.transfers = []
248    self.src_basenames = {}
249    self.src_numpatterns = {}
250
251    assert version in (1, 2, 3)
252
253    self.tgt = tgt
254    if src is None:
255      src = EmptyImage()
256    self.src = src
257
258    # The updater code that installs the patch always uses 4k blocks.
259    assert tgt.blocksize == 4096
260    assert src.blocksize == 4096
261
262    # The range sets in each filemap should comprise a partition of
263    # the care map.
264    self.AssertPartition(src.care_map, src.file_map.values())
265    self.AssertPartition(tgt.care_map, tgt.file_map.values())
266
  def Compute(self, prefix):
    """Compute the diff and write all output files.

    Produces prefix + ".transfer.list", ".new.dat" and ".patch.dat".
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      # v1 has no stash support, so violated edges are cut by dropping
      # the overlapping source blocks.
      self.RemoveBackwardEdges()
    else:
      # v2+ can stash, so violated edges are reversed instead of cut.
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
301
302  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
303    data = source.ReadRangeSet(ranges)
304    ctx = sha1()
305
306    for p in data:
307      ctx.update(p)
308
309    return ctx.hexdigest()
310
  def WriteTransfers(self, prefix):
    """Write the transfer list for the computed sequence to disk.

    Emits prefix + ".transfer.list": the format version, the total number
    of blocks written, for v2+ the stash slot count and peak stash size
    (in blocks), then one command per line.
    """
    out = []

    # Total number of blocks written by the emitted commands (goes into
    # the transfer-list header).
    total = 0

    # Maps a stash's graph id (and, for v3, also its contents hash) to
    # its slot id / reference count.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # v2 stash slot ids are recycled: freed ids return to this min-heap.
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        # Version 1 has no stash support at all.
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the stash commands this transfer needs done before it runs.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # v3 addresses stashes by content hash; identical contents
          # share one stash via reference counting.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      # "free" commands to append after this transfer's own command.
      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Re-express the stashed blocks' position within the source
          # data stream rather than on the partition.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            # Drop one reference; free the stash when nobody uses it.
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          # Everything comes from stashes; "-" marks an empty range set.
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move with identical source and target is a no-op; emit
        # nothing for it.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold
      print("max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n" % (
          max_stashed_blocks, max_stashed_size, max_allowed,
          max_stashed_size * 100.0 / max_allowed))
531
  def ReviseStashSize(self):
    """Keep the runtime stash usage within the allowed cache size.

    Simulates stash usage over the transfer sequence; any command whose
    stash requirement would exceed the limit is converted into a "new"
    command (its data ships in the package instead of being stashed).
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))
          new_blocks += sr.size()

        cmd.ConvertToNew()

    print("  Total %d blocks are packed as new blocks due to insufficient "
          "cache size." % (new_blocks,))
601
  def ComputePatches(self, prefix):
    """Compute patch data for all "diff" transfers and write the data files.

    Writes prefix + ".new.dat" (concatenated data for "new" commands) and
    prefix + ".patch.dat" (concatenated bsdiff/imgdiff patches). "diff"
    transfers whose source and target contents are identical become
    "move"; the rest become "bsdiff"/"imgdiff" with patch_start and
    patch_len recorded on the transfer.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Workers pop from the end, so sorting ascending by target size
      # makes the largest (slowest) patches start first.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Record each patch's offset/length within the concatenated file.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
706
707  def AssertSequenceGood(self):
708    # Simulate the sequences of transfers we will output, and check that:
709    # - we never read a block after writing it, and
710    # - we write every block we care about exactly once.
711
712    # Start with no blocks having been touched yet.
713    touched = RangeSet()
714
715    # Imagine processing the transfers in order.
716    for xf in self.transfers:
717      # Check that the input blocks for this transfer haven't yet been touched.
718
719      x = xf.src_ranges
720      if self.version >= 2:
721        for _, sr in xf.use_stash:
722          x = x.subtract(sr)
723
724      assert not touched.overlaps(x)
725      # Check that the output blocks for this transfer haven't yet been touched.
726      assert not touched.overlaps(xf.tgt_ranges)
727      # Touch all the blocks written by this transfer.
728      touched = touched.union(xf.tgt_ranges)
729
730    # Check that we've written every target block.
731    assert touched == self.tgt.care_map
732
733  def ImproveVertexSequence(self):
734    print("Improving vertex order...")
735
736    # At this point our digraph is acyclic; we reversed any edges that
737    # were backwards in the heuristically-generated sequence.  The
738    # previously-generated order is still acceptable, but we hope to
739    # find a better order that needs less memory for stashed data.
740    # Now we do a topological sort to generate a new vertex order,
741    # using a greedy algorithm to choose which vertex goes next
742    # whenever we have a choice.
743
744    # Make a copy of the edge set; this copy will get destroyed by the
745    # algorithm.
746    for xf in self.transfers:
747      xf.incoming = xf.goes_after.copy()
748      xf.outgoing = xf.goes_before.copy()
749
750    L = []   # the new vertex order
751
752    # S is the set of sources in the remaining graph; we always choose
753    # the one that leaves the least amount of stashed data after it's
754    # executed.
755    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
756         if not u.incoming]
757    heapq.heapify(S)
758
759    while S:
760      _, _, xf = heapq.heappop(S)
761      L.append(xf)
762      for u in xf.outgoing:
763        del u.incoming[xf]
764        if not u.incoming:
765          heapq.heappush(S, (u.NetStashChange(), u.order, u))
766
767    # if this fails then our graph had a cycle.
768    assert len(L) == len(self.transfers)
769
770    self.transfers = L
771    for i, xf in enumerate(L):
772      xf.order = i
773
774  def RemoveBackwardEdges(self):
775    print("Removing backward edges...")
776    in_order = 0
777    out_of_order = 0
778    lost_source = 0
779
780    for xf in self.transfers:
781      lost = 0
782      size = xf.src_ranges.size()
783      for u in xf.goes_before:
784        # xf should go before u
785        if xf.order < u.order:
786          # it does, hurray!
787          in_order += 1
788        else:
789          # it doesn't, boo.  trim the blocks that u writes from xf's
790          # source, so that xf can go after u.
791          out_of_order += 1
792          assert xf.src_ranges.overlaps(u.tgt_ranges)
793          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
794          xf.intact = False
795
796      if xf.style == "diff" and not xf.src_ranges:
797        # nothing left to diff from; treat as new data
798        xf.style = "new"
799
800      lost = size - xf.src_ranges.size()
801      lost_source += lost
802
803    print(("  %d/%d dependencies (%.2f%%) were violated; "
804           "%d source blocks removed.") %
805          (out_of_order, in_order + out_of_order,
806           (out_of_order * 100.0 / (in_order + out_of_order))
807           if (in_order + out_of_order) else 0.0,
808           lost_source))
809
810  def ReverseBackwardEdges(self):
811    print("Reversing backward edges...")
812    in_order = 0
813    out_of_order = 0
814    stashes = 0
815    stash_size = 0
816
817    for xf in self.transfers:
818      for u in xf.goes_before.copy():
819        # xf should go before u
820        if xf.order < u.order:
821          # it does, hurray!
822          in_order += 1
823        else:
824          # it doesn't, boo.  modify u to stash the blocks that it
825          # writes that xf wants to read, and then require u to go
826          # before xf.
827          out_of_order += 1
828
829          overlap = xf.src_ranges.intersect(u.tgt_ranges)
830          assert overlap
831
832          u.stash_before.append((stashes, overlap))
833          xf.use_stash.append((stashes, overlap))
834          stashes += 1
835          stash_size += overlap.size()
836
837          # reverse the edge direction; now xf must go after u
838          del xf.goes_before[u]
839          del u.goes_after[xf]
840          xf.goes_after[u] = None    # value doesn't matter
841          u.goes_before[xf] = None
842
843    print(("  %d/%d dependencies (%.2f%%) were violated; "
844           "%d source blocks stashed.") %
845          (out_of_order, in_order + out_of_order,
846           (out_of_order * 100.0 / (in_order + out_of_order))
847           if (in_order + out_of_order) else 0.0,
848           stash_size))
849
  def FindVertexSequence(self):
    """Order self.transfers so as to (approximately) minimize the total
    weight of violated dependency edges.

    Sets each transfer's .order field to its position in the chosen
    sequence and rewrites self.transfers in that order.  Edges that end
    up pointing "backward" in this order will later be broken (see
    RemoveBackwardEdges / ReverseBackwardEdges)."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.  A sink has no
      # outgoing edges, so placing it last violates nothing.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          # Detach u from the working graph.
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.  A source
      # has no incoming edges, so placing it first violates nothing.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      # Trimming sinks/sources may have emptied the graph.
      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        # d = (weight of edges satisfied) - (weight of edges violated)
        # if u is placed next on the left side.  Ties keep the earlier
        # vertex, so the result stays deterministic.
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      # The scratch edge copies are no longer needed.
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers
936
937  def GenerateDigraph(self):
938    print("Generating digraph...")
939    for a in self.transfers:
940      for b in self.transfers:
941        if a is b:
942          continue
943
944        # If the blocks written by A are read by B, then B needs to go before A.
945        i = a.tgt_ranges.intersect(b.src_ranges)
946        if i:
947          if b.src_name == "__ZERO":
948            # the cost of removing source blocks for the __ZERO domain
949            # is (nearly) zero.
950            size = 0
951          else:
952            size = i.size()
953          b.goes_before[a] = size
954          a.goes_after[b] = size
955
  def FindTransfers(self):
    """Parse the file_map to generate all the transfers.

    Walks self.tgt.file_map and creates one Transfer (or several, when a
    large diff transfer is split) per target domain, appended to
    self.transfers via the by_id argument of Transfer().  A source is
    matched in this order: the special __ZERO and __COPY domains, an
    exact pathname match, an exact basename match, then a basename match
    with digit runs replaced by "#".  Anything unmatched becomes "new"
    data sent as-is."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
      However, with the growth of file size and the shrink of the cache
      partition source blocks are too large to be stashed. If a file occupies
      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
      into smaller pieces by getting multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff but
      only bsdiff."""

      # Cap on blocks per single diff transfer; only enforced when
      # split=True (i.e. BBOTA v3 or later).
      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Peel off chunks of MAX_BLOCKS_PER_DIFF_TRANSFER blocks from the
      # front of both range sets, pairing the i-th target chunk with the
      # i-th source chunk.  Each piece gets a "-<n>" name suffix.
      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        # Splitting applies to diff transfers from version 3 on.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      # No plausible source found; send the target data as-is.
      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1051
1052  def AbbreviateSourceNames(self):
1053    for k in self.src.file_map.keys():
1054      b = os.path.basename(k)
1055      self.src_basenames[b] = k
1056      b = re.sub("[0-9]+", "#", b)
1057      self.src_numpatterns[b] = k
1058
1059  @staticmethod
1060  def AssertPartition(total, seq):
1061    """Assert that all the RangeSets in 'seq' form a partition of the
1062    'total' RangeSet (ie, they are nonintersecting and their union
1063    equals 'total')."""
1064    so_far = RangeSet()
1065    for i in seq:
1066      assert not so_far.overlaps(i)
1067      so_far = so_far.union(i)
1068    assert so_far == total
1069