blockimgdiff.py revision 4fcb77e4e35db7fccb7a35ad650a85dc29bd2e73
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import tempfile
import threading

from rangelib import RangeSet
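# (A RangeSet represents a set of non-overlapping block ranges; its "data="
# form takes a flat sequence of [start, end) pairs, so for example
# RangeSet(data=(0, 5, 10, 12)) covers blocks 0-4 and 10-11.)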


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
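  """Returns a patch that turns the concatenation of src into that of tgt.

  src and tgt are sequences of byte strings; each is written out to a temp
  file, then either imgdiff (with -z, for zip-formatted data) or bsdiff is
  run on the pair. Raises ValueError if the external diff tool fails.
  """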
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}

  def ReadRangeSet(self, ranges):
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even
    # for incremental OTAs; otherwise the last block may be skipped if it is
    # unchanged, yet fail the post-install verification because of non-zero
    # contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    self.clobbered_blocks = (RangeSet(data=clobbered_blocks)
                             if clobbered_blocks else RangeSet())
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

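    # Classify each block by appending a [start, end) pair per block; the
    # adjacent single-block pairs are expected to coalesce into larger
    # ranges when the RangeSets are constructed below.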
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      return sha1("".join(self.ReadRangeSet(ranges))).hexdigest()
    else:
      return sha1(self.data).hexdigest()


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
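    """Returns the net change in stashed blocks caused by this transfer.

    That is, blocks newly stashed by stash_before minus blocks freed by
    use_stash; positive values mean executing this transfer grows the stash.
    """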
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

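# For example (a sketch only; the real driver is BlockDifference in
# common.py), a caller might do:
#
#   b = BlockImageDiff(tgt_image, src_image, version=3)
#   b.Compute("system")
#
# which writes system.transfer.list, system.new.dat and system.patch.dat.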
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
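    """Returns the hex SHA-1 of the data in 'ranges' read from 'source'."""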
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash is used only once. Free it immediately after the use,
            # instead of waiting for the automatic clean-up at the end;
            # otherwise it may take up extra space and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh,))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). The threshold exists
        # because a) part of the cache may already be occupied by recovery
        # logs, and b) it buys us some headroom to deal with oversize
        # issues.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, max_stashed_size, max_allowed,
              max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, max_stashed_size))

  def ReviseStashSize(self):
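    """Revises the stash plan so it fits within the available cache space.

    Walks the transfers in order, tracking the stash size they would need
    at runtime; any command whose stash demand would exceed the allowed
    maximum is converted to a "new" command (shipped as full blocks).
    Assumes common.OPTIONS.cache_size has been set.
    """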
550    print("Revising stash size...")
551    stashes = {}
552
553    # Create the map between a stash and its def/use points. For example, for a
554    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
555    for xf in self.transfers:
556      # Command xf defines (stores) all the stashes in stash_before.
557      for idx, sr in xf.stash_before:
558        stashes[idx] = (sr, xf)
559
560      # Record all the stashes command xf uses.
561      for idx, _ in xf.use_stash:
562        stashes[idx] += (xf,)
563
564    # Compute the maximum blocks available for stash based on /cache size and
565    # the threshold.
566    cache_size = common.OPTIONS.cache_size
567    stash_threshold = common.OPTIONS.stash_threshold
568    max_allowed = cache_size * stash_threshold / self.tgt.blocksize
569
570    stashed_blocks = 0
571    new_blocks = 0
572
573    # Now go through all the commands. Compute the required stash size on the
574    # fly. If a command requires excess stash than available, it deletes the
575    # stash by replacing the command that uses the stash with a "new" command
576    # instead.
577    for xf in self.transfers:
578      replaced_cmds = []
579
580      # xf.stash_before generates explicit stash commands.
581      for idx, sr in xf.stash_before:
582        if stashed_blocks + sr.size() > max_allowed:
583          # We cannot stash this one for a later command. Find out the command
584          # that will use this stash and replace the command with "new".
585          use_cmd = stashes[idx][2]
586          replaced_cmds.append(use_cmd)
587          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
588        else:
589          stashed_blocks += sr.size()
590
591      # xf.use_stash generates free commands.
592      for _, sr in xf.use_stash:
593        stashed_blocks -= sr.size()
594
595      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
596      # ComputePatches(), they both have the style of "diff".
597      if xf.style == "diff" and self.version >= 3:
598        assert xf.tgt_ranges and xf.src_ranges
599        if xf.src_ranges.overlaps(xf.tgt_ranges):
600          if stashed_blocks + xf.src_ranges.size() > max_allowed:
601            replaced_cmds.append(xf)
602            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))
603
604      # Replace the commands in replaced_cmds with "new"s.
605      for cmd in replaced_cmds:
606        # It no longer uses any commands in "use_stash". Remove the def points
607        # for all those stashes.
608        for idx, sr in cmd.use_stash:
609          def_cmd = stashes[idx][1]
610          assert (idx, sr) in def_cmd.stash_before
611          def_cmd.stash_before.remove((idx, sr))
612          new_blocks += sr.size()
613
614        cmd.ConvertToNew()
615
616    print("  Total %d blocks are packed as new blocks due to insufficient "
617          "cache size." % (new_blocks,))
618
  def ComputePatches(self, prefix):
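    """Decides move vs. patch for each diff transfer and emits the data.

    Transfers whose source and target are identical become "move" commands;
    the rest are patched (with bsdiff, or imgdiff for intact zip-based
    files) across worker threads. Writes prefix.new.dat and
    prefix.patch.dat, and records each transfer's patch offset and length.
    """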
620    print("Reticulating splines...")
621    diff_q = []
622    patch_num = 0
623    with open(prefix + ".new.dat", "wb") as new_f:
624      for xf in self.transfers:
625        if xf.style == "zero":
626          pass
627        elif xf.style == "new":
628          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
629            new_f.write(piece)
630        elif xf.style == "diff":
631          src = self.src.ReadRangeSet(xf.src_ranges)
632          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)
633
634          # We can't compare src and tgt directly because they may have
635          # the same content but be broken up into blocks differently, eg:
636          #
637          #    ["he", "llo"]  vs  ["h", "ello"]
638          #
639          # We want those to compare equal, ideally without having to
640          # actually concatenate the strings (these may be tens of
641          # megabytes).
642
643          src_sha1 = sha1()
644          for p in src:
645            src_sha1.update(p)
646          tgt_sha1 = sha1()
647          tgt_size = 0
648          for p in tgt:
649            tgt_sha1.update(p)
650            tgt_size += len(p)
651
652          if src_sha1.digest() == tgt_sha1.digest():
653            # These are identical; we don't need to generate a patch,
654            # just issue copy commands on the device.
655            xf.style = "move"
656          else:
657            # For files in zip format (eg, APKs, JARs, etc.) we would
658            # like to use imgdiff -z if possible (because it usually
659            # produces significantly smaller patches than bsdiff).
660            # This is permissible if:
661            #
662            #  - the source and target files are monotonic (ie, the
663            #    data is stored with blocks in increasing order), and
664            #  - we haven't removed any blocks from the source set.
665            #
666            # If these conditions are satisfied then appending all the
667            # blocks in the set together in order will produce a valid
668            # zip file (plus possibly extra zeros in the last block),
669            # which is what imgdiff needs to operate.  (imgdiff is
670            # fine with extra zeros at the end of the file.)
671            imgdiff = (xf.intact and
672                       xf.tgt_name.split(".")[-1].lower()
673                       in ("apk", "jar", "zip"))
674            xf.style = "imgdiff" if imgdiff else "bsdiff"
675            diff_q.append((tgt_size, src, tgt, xf, patch_num))
676            patch_num += 1
677
678        else:
679          assert False, "unknown style " + xf.style
680
681    if diff_q:
682      if self.threads > 1:
683        print("Computing patches (using %d threads)..." % (self.threads,))
684      else:
685        print("Computing patches...")
686      diff_q.sort()
687
688      patches = [None] * patch_num
689
690      # TODO: Rewrite with multiprocessing.ThreadPool?
691      lock = threading.Lock()
692      def diff_worker():
693        while True:
694          with lock:
695            if not diff_q:
696              return
697            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
698          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
699          size = len(patch)
700          with lock:
701            patches[patchnum] = (patch, xf)
702            print("%10d %10d (%6.2f%%) %7s %s" % (
703                size, tgt_size, size * 100.0 / tgt_size, xf.style,
704                xf.tgt_name if xf.tgt_name == xf.src_name else (
705                    xf.tgt_name + " (from " + xf.src_name + ")")))
706
707      threads = [threading.Thread(target=diff_worker)
708                 for _ in range(self.threads)]
709      for th in threads:
710        th.start()
711      while threads:
712        threads.pop().join()
713    else:
714      patches = []
715
716    p = 0
717    with open(prefix + ".patch.dat", "wb") as patch_f:
718      for patch, xf in patches:
719        xf.patch_start = p
720        xf.patch_len = len(patch)
721        patch_f.write(patch)
722        p += len(patch)
723
  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.
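    # (This is Kahn's topological sort, with a min-heap so that among the
    # ready vertices we always pick the one with the smallest net stash
    # change.)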

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # If this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
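    #
    # For example (a toy instance): with transfers A, B, C and weighted
    # edges A->B (5), B->C (1), C->A (2), there are no sources or sinks,
    # so we pick the vertex maximizing outgoing minus incoming weight
    # (A, with 5 - 2 = 3) and place it first; B then becomes a source and
    # C a sink, giving the order A, B, C and violating only the cheapest
    # edge, C->A.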

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks saved if we treat it
      # as a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
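    """Builds the goes_before/goes_after edges between transfers.

    For every pair (a, b): if a writes blocks that b reads, then b must go
    before a; the edge weight is the number of overlapping blocks. This
    pairwise check is O(n^2) in the number of transfers.
    """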
955    print("Generating digraph...")
956    for a in self.transfers:
957      for b in self.transfers:
958        if a is b:
959          continue
960
961        # If the blocks written by A are read by B, then B needs to go before A.
962        i = a.tgt_ranges.intersect(b.src_ranges)
963        if i:
964          if b.src_name == "__ZERO":
965            # the cost of removing source blocks for the __ZERO domain
966            # is (nearly) zero.
967            size = 0
968          else:
969            size = i.size()
970          b.goes_before[a] = size
971          a.goes_after[b] = size
972
  def FindTransfers(self):
    """Parses the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks to support resumable
      updates. However, with files growing and the cache partition
      shrinking, source blocks may be too large to stash. If a file
      occupies too many blocks (more than MAX_BLOCKS_PER_DIFF_TRANSFER),
      we split it into smaller pieces by creating multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff,
      only bsdiff."""

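      # For example, a 2500-block file is emitted as three pieces of 1024,
      # 1024 and 452 blocks, named "<name>-0", "<name>-1" and "<name>-2".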
      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # The special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # The "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Asserts that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total