blockimgdiff.py revision ff75c230860ff3e88c1e73cfcc62270d76af29bc
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import array
import common
import functools
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import time
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
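  """Return the patch bytes that turn the src data into the tgt data.

  src and tgt are lists/tuples of strings (block data); they are written to
  temporary files and handed to the external imgdiff or bsdiff tool, and the
  resulting patch file is read back and returned.
  """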
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) / self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Keep this as a RangeSet so that care_map.subtract() works on it.
    self.clobbered_blocks = RangeSet(data=clobbered_blocks)
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # ReadRangeSet() returns a list of strings; join them before hashing.
      return sha1("".join(self.ReadRangeSet(ranges))).hexdigest()
    else:
      return sha1(self.data).hexdigest()


class Transfer(object):
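  """One transfer in the update: write tgt_ranges in the target image using
  data read from src_ranges in the source image (and/or from stash), with
  the given style ("new", "zero", "move", "diff"/"bsdiff"/"imgdiff")."""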
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
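    """Return the net change in stashed blocks if this transfer runs: blocks
    it stashes for later commands minus blocks it frees by using stashes."""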
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
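    """Degrade this transfer to a "new" command that reads nothing from the
    source image and uses no stash."""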
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


@functools.total_ordering
class HeapItem(object):
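  """Heap entry wrapping an item (a Transfer) so it can sit in a max-heap
  keyed by its score; heapq is a min-heap, so the score is negated. Stale
  entries are invalidated with clear() and skipped when popped."""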
  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score
  def clear(self):
    self.item = None
  def __bool__(self):
    return self.item is not None
  # Python 2 evaluates truthiness via __nonzero__ rather than __bool__.
  __nonzero__ = __bool__
  def __eq__(self, other):
    return self.score == other.score
  def __le__(self, other):
    return self.score <= other.score


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=4):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}
    self._max_stashed_size = 0

    assert version in (1, 2, 3, 4)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  @property
  def max_stashed_size(self):
    return self._max_stashed_size

  def Compute(self, prefix):
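    """Compute the transfer list and write the output files.

    Writes <prefix>.transfer.list, <prefix>.new.dat and <prefix>.patch.dat.
    """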
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
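    """Return the hex SHA-1 of the data in 'ranges', read from 'source'."""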
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
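    """Write the transfer list to <prefix>.transfer.list, one command per
    line, tracking the number of stash slots and the peak stash size."""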
    out = []

    total = 0

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, self._max_stashed_size, max_allowed,
              self._max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, self._max_stashed_size))

  def ReviseStashSize(self):
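    """Revise the stashing plan so the peak stash size stays within the limit
    derived from the cache size, converting offending commands to "new"."""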
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands, computing the required stash size on
    # the fly. If a command would need more stash than is available, drop that
    # stash by replacing the command that uses it with a "new" command instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up the blocks that violate the space limit; the total is printed
        # later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))

  def ComputePatches(self, prefix):
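    """Decide whether each "diff" transfer becomes a move or a bsdiff/imgdiff,
    compute the patches (possibly on multiple threads), and write
    <prefix>.new.dat and <prefix>.patch.dat."""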
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = array.array("B", "\0" * self.tgt.total_blocks)

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      for s, e in x:
        # The source image could be larger. Don't check the blocks that exist
        # only in the source image, since they are not in 'touched' and will
        # never be touched.
        for i in range(s, min(e, self.tgt.total_blocks)):
          assert touched[i] == 0

      # Check that the output blocks for this transfer haven't yet
      # been touched, and touch all the blocks written by this
      # transfer.
      for s, e in xf.tgt_ranges:
        for i in range(s, e):
          assert touched[i] == 0
          touched[i] = 1

    # Check that we've written every target block.
    for s, e in self.tgt.care_map:
      for i in range(s, e):
        assert touched[i] == 1

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
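    """Version 1 only: resolve violated ordering dependencies by removing the
    overlapping blocks from the out-of-order transfer's source."""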
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
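    """Version 2+: resolve violated ordering dependencies by stashing the
    overlapping blocks and reversing the edge, so the writer runs first."""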
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    sinks = set(u for u in G if not u.outgoing)
    sources = set(u for u in G if not u.incoming)

    def adjust_score(iu, delta):
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = set()
        for u in sinks:
          if u not in G: continue
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing: new_sinks.add(iu)
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = set()
        for u in sources:
          if u not in G: continue
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming: new_sources.add(iu)
        sources = new_sources

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      while True:
        u = heapq.heappop(heap)
        if u and u.item in G:
          u = u.item
          break

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming: sources.add(iu)

      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing: sinks.add(iu)

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
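    """Build goes_before/goes_after edges: if transfer A writes blocks that
    transfer B reads, then B must go before A, weighted by the overlap size."""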
    print("Generating digraph...")

    # Each item of source_ranges will be:
    #   - None, if that block is not used as a source,
    #   - a transfer, if one transfer uses it as a source, or
    #   - a set of transfers.
    source_ranges = []
    for b in self.transfers:
      for s, e in b.src_ranges:
        if e > len(source_ranges):
          source_ranges.extend([None] * (e-len(source_ranges)))
        for i in range(s, e):
          if source_ranges[i] is None:
            source_ranges[i] = b
          else:
            if not isinstance(source_ranges[i], set):
              source_ranges[i] = set([source_ranges[i]])
            source_ranges[i].add(b)

    for a in self.transfers:
      intersections = set()
      for s, e in a.tgt_ranges:
        for i in range(s, e):
          if i >= len(source_ranges): break
          b = source_ranges[i]
          if b is not None:
            if isinstance(b, set):
              intersections.update(b)
            else:
              intersections.add(b)

      for b in intersections:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks to support the resumable
      feature. However, with the growth of file size and the shrinking of the
      cache partition, source blocks may be too large to be stashed. If a file
      occupies too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we
      split it into smaller pieces by generating multiple Transfer()s.

      The downside is that after splitting, we may increase the package size
      since the split pieces don't align well. According to our experiments,
      1/8 of the cache size as the per-piece limit appears to be optimal.
      Compared to the fixed 1024-block limit, it reduces the overall package
      size by 30% for volantis, and 20% for angler and bullhead."""

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      cache_size = common.OPTIONS.cache_size
      split_threshold = 0.125
      max_blocks_per_transfer = int(cache_size * split_threshold /
                                    self.tgt.blocksize)

      # Change nothing for small files.
      if (tgt_ranges.size() <= max_blocks_per_transfer and
          src_ranges.size() <= max_blocks_per_transfer):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      while (tgt_ranges.size() > max_blocks_per_transfer and
             src_ranges.size() > max_blocks_per_transfer):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
        src_first = src_ranges.first(max_blocks_per_transfer)

        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
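    """Index the source files by basename and by basename with runs of digits
    replaced by "#", for the fuzzy source matching done in FindTransfers()."""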
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""

    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total