blockimgdiff.py revision ebe39a096a1c8c618f0bc34e1a5c1ab4f8d1ee44
1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import common
20import heapq
21import itertools
22import multiprocessing
23import os
24import re
25import subprocess
26import threading
27import tempfile
28
29from rangelib import RangeSet
30
31
32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
33
34
def compute_patch(src, tgt, imgdiff=False):
  """Return the binary patch transforming src data into tgt data.

  Args:
    src: iterable of strings; the source data, written to a temp file.
    tgt: iterable of strings; the target data, written to a temp file.
    imgdiff: if True, run "imgdiff -z" (zip-aware); otherwise run "bsdiff".

  Returns:
    The raw contents of the generated patch file.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tools create the output file themselves; remove the
    # placeholder mkstemp() made for us (keeping the unique name).
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Silence imgdiff's chatter; close the devnull handle afterwards
      # instead of leaking one open file descriptor per call.
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass
72
73
class Image(object):
  """Abstract interface that BlockImageDiff expects of image objects."""

  def ReadRangeSet(self, ranges):
    """Return the data contained in the given ranges; must be overridden."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 over the image data; must be overridden."""
    raise NotImplementedError
80
81
class EmptyImage(Image):
  """A zero-length image."""

  blocksize = 4096
  total_blocks = 0
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    """An empty image has no data for any range."""
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 of zero bytes of data.

    EmptyImage always carries empty clobbered_blocks, so
    include_clobbered_blocks can be ignored.
    """
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()
97
98
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap `data` as an image.

    Args:
      data: the raw image contents.
      trim: drop any trailing partial block.
      pad: zero-fill any trailing partial block.

    Raises:
      ValueError: if data is not a whole number of blocks and neither
        trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int on Python 3 as well; the
    # assert above guarantees the division is exact.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # clobbered_blocks must be a RangeSet (not the raw pair list) to match
    # the image interface; TotalSha1() subtracts it from care_map.
    self.clobbered_blocks = RangeSet(
        data=clobbered_blocks) if clobbered_blocks else RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify every non-padded block as all-zero or not. The lists hold
    # flat (start, end) pairs in the form RangeSet(data=...) expects.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Return the data for `ranges` as a list of per-range chunks."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 over the image contents.

    Skips clobbered_blocks unless include_clobbered_blocks is True.
    """
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # ReadRangeSet() returns a list of chunks; sha1() cannot hash a
      # list directly, so feed the chunks through update().
      ctx = sha1()
      for chunk in self.ReadRangeSet(ranges):
        ctx.update(chunk)
      return ctx.hexdigest()
    else:
      return sha1(self.data).hexdigest()
169
170
class Transfer(object):
  """One command relating a target range set to a source range set."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # "Intact" requires both range sets to be monotonic; range-like
    # objects without the attribute count as not monotonic.
    monotonic_tgt = getattr(tgt_ranges, "monotonic", False)
    monotonic_src = getattr(src_ranges, "monotonic", False)
    self.intact = monotonic_tgt and monotonic_src

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Blocks this command stashes minus the blocks it unstashes."""
    stashed = sum(sr.size() for (_, sr) in self.stash_before)
    freed = sum(sr.size() for (_, sr) in self.use_stash)
    return stashed - freed

  def ConvertToNew(self):
    """Turn this command into a "new" that reads nothing from the source."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return "%d: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
205
206
207# BlockImageDiff works on two image objects.  An image object is
208# anything that provides the following attributes:
209#
210#    blocksize: the size in bytes of a block, currently must be 4096.
211#
212#    total_blocks: the total size of the partition/image, in blocks.
213#
214#    care_map: a RangeSet containing which blocks (in the range [0,
215#      total_blocks) we actually care about; i.e. which blocks contain
216#      data.
217#
218#    file_map: a dict that partitions the blocks contained in care_map
219#      into smaller domains that are useful for doing diffs on.
220#      (Typically a domain is a file, and the key in file_map is the
221#      pathname.)
222#
223#    clobbered_blocks: a RangeSet containing which blocks contain data
224#      but may be altered by the FS. They need to be excluded when
225#      verifying the partition integrity.
226#
227#    ReadRangeSet(): a function that takes a RangeSet and returns the
228#      data contained in the image blocks of that RangeSet.  The data
229#      is returned as a list or tuple of strings; concatenating the
230#      elements together should produce the requested data.
231#      Implementations are free to break up the data into list/tuple
232#      elements in any way that is convenient.
233#
234#    TotalSha1(): a function that returns (as a hex string) the SHA-1
235#      hash of all the data in the image (ie, all the blocks in the
236#      care_map minus clobbered_blocks, or including the clobbered
237#      blocks if include_clobbered_blocks is True).
238#
239# When creating a BlockImageDiff, the src image may be None, in which
240# case the list of transfers produced will never read from the
241# original image.
242
243class BlockImageDiff(object):
244  def __init__(self, tgt, src=None, threads=None, version=4):
245    if threads is None:
246      threads = multiprocessing.cpu_count() // 2
247      if threads == 0:
248        threads = 1
249    self.threads = threads
250    self.version = version
251    self.transfers = []
252    self.src_basenames = {}
253    self.src_numpatterns = {}
254
255    assert version in (1, 2, 3, 4)
256
257    self.tgt = tgt
258    if src is None:
259      src = EmptyImage()
260    self.src = src
261
262    # The updater code that installs the patch always uses 4k blocks.
263    assert tgt.blocksize == 4096
264    assert src.blocksize == 4096
265
266    # The range sets in each filemap should comprise a partition of
267    # the care map.
268    self.AssertPartition(src.care_map, src.file_map.values())
269    self.AssertPartition(tgt.care_map, tgt.file_map.values())
270
  def Compute(self, prefix):
    """Run the whole diff pipeline; write the output files under `prefix`.

    The call order below is significant: transfers must exist before the
    digraph is built, and stash revision must happen before the final
    sequence check and output.
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
305
306  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
307    data = source.ReadRangeSet(ranges)
308    ctx = sha1()
309
310    for p in data:
311      ctx.update(p)
312
313    return ctx.hexdigest()
314
  def WriteTransfers(self, prefix):
    """Write the transfer list for this diff to "<prefix>.transfer.list".

    Emits the versioned command stream (stash/free/move/bsdiff/imgdiff/
    zero/new/erase) that the updater executes, while tracking the peak
    number of simultaneously stashed blocks so the result can be checked
    against the allowed cache size.
    """
    out = []

    # Total number of target blocks produced; written into the header so
    # the updater can report progress.
    total = 0

    # Maps stash keys to slot ids (version 2) and stash content hashes to
    # reference counts (version 3+); tells us when a stash can be freed.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # Version 2 slot ids are recycled through a min-heap so the smallest
    # free id is always reused first.
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        # Version 1 has no stash support at all.
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the explicit "stash" commands this transfer requires.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3+ addresses stashes by content hash; identical
          # contents share one stash tracked with a reference count.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      # "free" commands emitted after this transfer's own command, and the
      # number of blocks they release.
      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        # Build the <src_str> argument, which takes one of these forms:
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Translate absolute source ranges into positions within the
          # assembled source data the updater will construct.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          # Return the numeric slot id to the pool for reuse.
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op on the device;
        # emit nothing in that case.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that are already part of the source ranges need not be
        # zeroed again.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, max_stashed_size, max_allowed,
              max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, max_stashed_size))
552
  def ReviseStashSize(self):
    """Keep the runtime stash footprint within the allowed cache size.

    Walks the transfer sequence computing the required stash size on the
    fly; any stash that would push usage past the limit is eliminated by
    converting the command that would consume it into a "new" command
    (trading package size for cache space).
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up blocks that violates space limit and print total number to
        # screen later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))
625
  def ComputePatches(self, prefix):
    """Compute patches for all "diff" transfers and write the data files.

    Writes "<prefix>.new.dat" (raw data for "new" commands) and
    "<prefix>.patch.dat" (concatenated bsdiff/imgdiff patches), resolves
    every "diff" transfer into its final style ("move", "bsdiff" or
    "imgdiff"), and records each patch's patch_start/patch_len offsets.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sorted ascending by target size; workers pop() from the end, so
      # the largest targets are processed first.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Each worker repeatedly takes one entry off diff_q (under the
        # lock) and computes its patch until the queue is drained.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into one file, recording each transfer's
    # offset and length within it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
730
731  def AssertSequenceGood(self):
732    # Simulate the sequences of transfers we will output, and check that:
733    # - we never read a block after writing it, and
734    # - we write every block we care about exactly once.
735
736    # Start with no blocks having been touched yet.
737    touched = RangeSet()
738
739    # Imagine processing the transfers in order.
740    for xf in self.transfers:
741      # Check that the input blocks for this transfer haven't yet been touched.
742
743      x = xf.src_ranges
744      if self.version >= 2:
745        for _, sr in xf.use_stash:
746          x = x.subtract(sr)
747
748      assert not touched.overlaps(x)
749      # Check that the output blocks for this transfer haven't yet been touched.
750      assert not touched.overlaps(xf.tgt_ranges)
751      # Touch all the blocks written by this transfer.
752      touched = touched.union(xf.tgt_ranges)
753
754    # Check that we've written every target block.
755    assert touched == self.tgt.care_map
756
757  def ImproveVertexSequence(self):
758    print("Improving vertex order...")
759
760    # At this point our digraph is acyclic; we reversed any edges that
761    # were backwards in the heuristically-generated sequence.  The
762    # previously-generated order is still acceptable, but we hope to
763    # find a better order that needs less memory for stashed data.
764    # Now we do a topological sort to generate a new vertex order,
765    # using a greedy algorithm to choose which vertex goes next
766    # whenever we have a choice.
767
768    # Make a copy of the edge set; this copy will get destroyed by the
769    # algorithm.
770    for xf in self.transfers:
771      xf.incoming = xf.goes_after.copy()
772      xf.outgoing = xf.goes_before.copy()
773
774    L = []   # the new vertex order
775
776    # S is the set of sources in the remaining graph; we always choose
777    # the one that leaves the least amount of stashed data after it's
778    # executed.
779    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
780         if not u.incoming]
781    heapq.heapify(S)
782
783    while S:
784      _, _, xf = heapq.heappop(S)
785      L.append(xf)
786      for u in xf.outgoing:
787        del u.incoming[xf]
788        if not u.incoming:
789          heapq.heappush(S, (u.NetStashChange(), u.order, u))
790
791    # if this fails then our graph had a cycle.
792    assert len(L) == len(self.transfers)
793
794    self.transfers = L
795    for i, xf in enumerate(L):
796      xf.order = i
797
798  def RemoveBackwardEdges(self):
799    print("Removing backward edges...")
800    in_order = 0
801    out_of_order = 0
802    lost_source = 0
803
804    for xf in self.transfers:
805      lost = 0
806      size = xf.src_ranges.size()
807      for u in xf.goes_before:
808        # xf should go before u
809        if xf.order < u.order:
810          # it does, hurray!
811          in_order += 1
812        else:
813          # it doesn't, boo.  trim the blocks that u writes from xf's
814          # source, so that xf can go after u.
815          out_of_order += 1
816          assert xf.src_ranges.overlaps(u.tgt_ranges)
817          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
818          xf.intact = False
819
820      if xf.style == "diff" and not xf.src_ranges:
821        # nothing left to diff from; treat as new data
822        xf.style = "new"
823
824      lost = size - xf.src_ranges.size()
825      lost_source += lost
826
827    print(("  %d/%d dependencies (%.2f%%) were violated; "
828           "%d source blocks removed.") %
829          (out_of_order, in_order + out_of_order,
830           (out_of_order * 100.0 / (in_order + out_of_order))
831           if (in_order + out_of_order) else 0.0,
832           lost_source))
833
834  def ReverseBackwardEdges(self):
835    print("Reversing backward edges...")
836    in_order = 0
837    out_of_order = 0
838    stashes = 0
839    stash_size = 0
840
841    for xf in self.transfers:
842      for u in xf.goes_before.copy():
843        # xf should go before u
844        if xf.order < u.order:
845          # it does, hurray!
846          in_order += 1
847        else:
848          # it doesn't, boo.  modify u to stash the blocks that it
849          # writes that xf wants to read, and then require u to go
850          # before xf.
851          out_of_order += 1
852
853          overlap = xf.src_ranges.intersect(u.tgt_ranges)
854          assert overlap
855
856          u.stash_before.append((stashes, overlap))
857          xf.use_stash.append((stashes, overlap))
858          stashes += 1
859          stash_size += overlap.size()
860
861          # reverse the edge direction; now xf must go after u
862          del xf.goes_before[u]
863          del u.goes_after[xf]
864          xf.goes_after[u] = None    # value doesn't matter
865          u.goes_before[xf] = None
866
867    print(("  %d/%d dependencies (%.2f%%) were violated; "
868           "%d source blocks stashed.") %
869          (out_of_order, in_order + out_of_order,
870           (out_of_order * 100.0 / (in_order + out_of_order))
871           if (in_order + out_of_order) else 0.0,
872           stash_size))
873
  def FindVertexSequence(self):
    """Linearly order the transfer digraph, minimizing backward-edge weight.

    Places every transfer on a line so that the total weight of edges
    pointing "backward" (dependencies the order violates) is heuristically
    small; a later pass (RemoveBackwardEdges/ReverseBackwardEdges) breaks
    those remaining edges.  On return, each transfer's 'order' field is
    set and self.transfers is rearranged into the chosen sequence.

    NOTE(review): the result depends on the (repeatable) iteration order
    of the transfers and on first-wins tie-breaking below; changes to
    statement order here can change the produced sequence.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        # Snapshot the sinks first: the inner loop mutates G.
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            # Removing u may create new sinks; the enclosing while
            # repeats until none remain.
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        # d = (weight of edges kept if u goes left) - (weight lost).
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          # Strict '>' means the first vertex with the max d wins,
          # keeping the output repeatable.
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      # Drop the scratch edge copies; only goes_before/goes_after remain.
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers
960
961  def GenerateDigraph(self):
962    print("Generating digraph...")
963    for a in self.transfers:
964      for b in self.transfers:
965        if a is b:
966          continue
967
968        # If the blocks written by A are read by B, then B needs to go before A.
969        i = a.tgt_ranges.intersect(b.src_ranges)
970        if i:
971          if b.src_name == "__ZERO":
972            # the cost of removing source blocks for the __ZERO domain
973            # is (nearly) zero.
974            size = 0
975          else:
976            size = i.size()
977          b.goes_before[a] = size
978          a.goes_after[b] = size
979
  def FindTransfers(self):
    """Parse the file_map to generate all the transfers.

    Walks every file in the target image's file_map and creates the
    Transfer objects (via AddTransfer) that will produce it: "zero" for
    the __ZERO domain, "new" for __COPY and for files with no usable
    source, and "diff" when a source match is found by exact path,
    basename, or digit-collapsed basename pattern.  Transfers register
    themselves into self.transfers as a side effect of construction.
    """

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
      However, with the growth of file size and the shrink of the cache
      partition source blocks are too large to be stashed. If a file occupies
      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
      into smaller pieces by getting multiple Transfer()s.

      The downside is that after splitting, we may increase the package size
      since the split pieces don't align well. According to our experiments,
      1/8 of the cache size as the per-piece limit appears to be optimal.
      Compared to the fixed 1024-block limit, it reduces the overall package
      size by 30% for volantis, and 20% for angler and bullhead."""

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      cache_size = common.OPTIONS.cache_size
      # Per-piece limit is 1/8 of the cache size (see docstring above),
      # converted from bytes to target-image blocks.
      split_threshold = 0.125
      max_blocks_per_transfer = int(cache_size * split_threshold /
                                    self.tgt.blocksize)

      # Change nothing for small files.
      if (tgt_ranges.size() <= max_blocks_per_transfer and
          src_ranges.size() <= max_blocks_per_transfer):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Peel off max-sized pieces from the front of both range sets until
      # either side drops under the limit.  Each piece gets a "-<n>"
      # suffixed name so the split transfers stay distinguishable.
      while (tgt_ranges.size() > max_blocks_per_transfer and
             src_ranges.size() > max_blocks_per_transfer):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
        src_first = src_ranges.first(max_blocks_per_transfer)

        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        # Splitting is only enabled for BBOTA v3+ (resumable updates).
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      # No source match at all: send the full target data.
      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1082
1083  def AbbreviateSourceNames(self):
1084    for k in self.src.file_map.keys():
1085      b = os.path.basename(k)
1086      self.src_basenames[b] = k
1087      b = re.sub("[0-9]+", "#", b)
1088      self.src_numpatterns[b] = k
1089
1090  @staticmethod
1091  def AssertPartition(total, seq):
1092    """Assert that all the RangeSets in 'seq' form a partition of the
1093    'total' RangeSet (ie, they are nonintersecting and their union
1094    equals 'total')."""
1095    so_far = RangeSet()
1096    for i in seq:
1097      assert not so_far.overlaps(i)
1098      so_far = so_far.union(i)
1099    assert so_far == total
1100