# blockimgdiff.py revision 66f1fa65658d33f9bdc42691738a5fb6559560cb
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import array
20import common
21import functools
22import heapq
23import itertools
24import multiprocessing
25import os
26import re
27import subprocess
28import threading
29import time
30import tempfile
31
32from rangelib import RangeSet
33
34
35__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
36
37
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms src into tgt.

  Args:
    src: iterable of strings holding the source data pieces.
    tgt: iterable of strings holding the target data pieces.
    imgdiff: if True, use "imgdiff -z" (zip-aware); otherwise "bsdiff".

  Returns:
    The patch contents, as a single string of bytes.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # Let the diff tool create the patch file itself.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Close the null sink when the call returns (the previous version
      # leaked the open file object).
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup of each temp file individually, so one failed
    # unlink doesn't prevent removing the others.
    for fn in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(fn)
      except OSError:
        pass
75
76
class Image(object):
  """Abstract interface consumed by BlockImageDiff.

  See the comment block above BlockImageDiff for the full attribute and
  method contract an image object must provide.
  """

  def ReadRangeSet(self, ranges):
    """Return the data in 'ranges' as a list/tuple of strings."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return (as a hex string) the SHA-1 of the image's data."""
    raise NotImplementedError
83
84
class EmptyImage(Image):
  """A zero-length image."""
  # An empty image still satisfies the Image interface: no blocks, no
  # cared-for data, nothing clobbered.
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    # No data, regardless of what ranges are requested.
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    # SHA-1 of zero bytes of input.
    return sha1().hexdigest()
100
101
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap a raw partition-contents string as an Image.

    Args:
      data: the partition contents, as a string.
      trim: if the data is not a whole number of blocks, drop the
          trailing partial block.
      pad: if the data is not a whole number of blocks, zero-fill the
          trailing partial block.

    Raises:
      ValueError: if data is a partial number of blocks and neither trim
          nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Use floor division so total_blocks stays an int even under true
    # division (Python 3 / "from __future__ import division").
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Store a RangeSet rather than the raw list: TotalSha1() subtracts
    # clobbered_blocks from care_map, which requires RangeSet semantics.
    self.clobbered_blocks = (RangeSet(data=clobbered_blocks)
                             if clobbered_blocks else RangeSet())
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify each (non-padded) block as all-zero or not. The lists are
    # built in the flat [start, end, start, end, ...] form RangeSet takes.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Return the data in 'ranges' as a list of block-aligned slices."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the data, optionally skipping clobbered
    blocks (skipped by default)."""
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # sha1() requires a single string; ReadRangeSet() returns a list
      # of pieces (passing the list itself, as the original code did,
      # raises TypeError). Join the pieces first.
      return sha1(''.join(self.ReadRangeSet(ranges))).hexdigest()
    else:
      return sha1(self.data).hexdigest()
172
173
class Transfer(object):
  """One unit of update work: produce tgt_ranges in the target image,
  possibly reading src_ranges from the source image."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style

    # "Intact" means both range sets are monotonic (blocks appear in
    # increasing order), a precondition for using imgdiff later.
    monotonic = lambda r: getattr(r, "monotonic", False)
    self.intact = monotonic(tgt_ranges) and monotonic(src_ranges)

    # OrderedDict keeps the output repeatable; a plain dict's iteration
    # order would depend on the Transfer objects' hash values.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    # Register ourselves; our position in by_id is our id.
    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Blocks this command stashes minus blocks it frees from stash."""
    stored = sum(sr.size() for (_, sr) in self.stash_before)
    freed = sum(sr.size() for (_, sr) in self.use_stash)
    return stored - freed

  def ConvertToNew(self):
    """Downgrade to a "new" command: ship raw data, read no source and
    use no stashes."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return "%s: <%s %s to %s>" % (self.id, self.src_ranges, self.style,
                                  self.tgt_ranges)
208
209
@functools.total_ordering
class HeapItem(object):
  """Wraps an object with a numeric 'score' for use on a min-heap so
  that the highest-scoring item pops first; cleared items are falsy."""

  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score

  def clear(self):
    # Drop the payload; truth-testing then reports the entry as dead.
    self.item = None

  def __bool__(self):
    # Bug fix: this previously returned "self.item is None", inverting
    # the truth value so cleared entries looked live (and vice versa).
    return self.item is not None

  # Python 2 uses __nonzero__ for truth testing.
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self.score == other.score

  def __le__(self, other):
    return self.score <= other.score
225
226
227# BlockImageDiff works on two image objects.  An image object is
228# anything that provides the following attributes:
229#
230#    blocksize: the size in bytes of a block, currently must be 4096.
231#
232#    total_blocks: the total size of the partition/image, in blocks.
233#
234#    care_map: a RangeSet containing which blocks (in the range [0,
235#      total_blocks) we actually care about; i.e. which blocks contain
236#      data.
237#
238#    file_map: a dict that partitions the blocks contained in care_map
239#      into smaller domains that are useful for doing diffs on.
240#      (Typically a domain is a file, and the key in file_map is the
241#      pathname.)
242#
243#    clobbered_blocks: a RangeSet containing which blocks contain data
244#      but may be altered by the FS. They need to be excluded when
245#      verifying the partition integrity.
246#
247#    ReadRangeSet(): a function that takes a RangeSet and returns the
248#      data contained in the image blocks of that RangeSet.  The data
249#      is returned as a list or tuple of strings; concatenating the
250#      elements together should produce the requested data.
251#      Implementations are free to break up the data into list/tuple
252#      elements in any way that is convenient.
253#
254#    TotalSha1(): a function that returns (as a hex string) the SHA-1
255#      hash of all the data in the image (ie, all the blocks in the
256#      care_map minus clobbered_blocks, or including the clobbered
257#      blocks if include_clobbered_blocks is True).
258#
259# When creating a BlockImageDiff, the src image may be None, in which
260# case the list of transfers produced will never read from the
261# original image.
262
263class BlockImageDiff(object):
264  def __init__(self, tgt, src=None, threads=None, version=4):
265    if threads is None:
266      threads = multiprocessing.cpu_count() // 2
267      if threads == 0:
268        threads = 1
269    self.threads = threads
270    self.version = version
271    self.transfers = []
272    self.src_basenames = {}
273    self.src_numpatterns = {}
274    self._max_stashed_size = 0
275    self.touched_src_ranges = RangeSet()
276    self.touched_src_sha1 = None
277
278    assert version in (1, 2, 3, 4)
279
280    self.tgt = tgt
281    if src is None:
282      src = EmptyImage()
283    self.src = src
284
285    # The updater code that installs the patch always uses 4k blocks.
286    assert tgt.blocksize == 4096
287    assert src.blocksize == 4096
288
289    # The range sets in each filemap should comprise a partition of
290    # the care map.
291    self.AssertPartition(src.care_map, src.file_map.values())
292    self.AssertPartition(tgt.care_map, tgt.file_map.values())
293
  @property
  def max_stashed_size(self):
    # Maximum number of bytes simultaneously held in stash, as computed
    # by WriteTransfers(); stays 0 until that runs.
    return self._max_stashed_size
297
  def Compute(self, prefix):
    """Compute the full transfer list and write all output files.

    Produces <prefix>.transfer.list, <prefix>.new.dat and
    <prefix>.patch.dat for the tgt/src image pair.
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
332
333  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
334    data = source.ReadRangeSet(ranges)
335    ctx = sha1()
336
337    for p in data:
338      ctx.update(p)
339
340    return ctx.hexdigest()
341
  def WriteTransfers(self, prefix):
    """Write the ordered transfer list to <prefix>.transfer.list.

    Emits the format version, the total number of target blocks written,
    (v2+) the stash slot/size headers, and one command line per
    transfer, while tracking the maximum number of blocks simultaneously
    stashed so the result can be checked against the /cache size limit.
    """
    out = []

    total = 0

    # v2 keys 'stashes' by slot id; v3+ additionally keys by content
    # hash with a reference count, so identical ranges share one stash.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit stash commands for everything this transfer needs stashed
      # before it runs.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            self.touched_src_ranges = self.touched_src_ranges.union(sr)
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      # Build the <src_str> field describing where this transfer's
      # source data comes from (direct ranges and/or stash references).
      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op on the device
        # and is simply omitted.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            self.touched_src_ranges = self.touched_src_ranges.union(
                xf.src_ranges)

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          self.touched_src_ranges = self.touched_src_ranges.union(
              xf.src_ranges)

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    if self.version >= 3:
      self.touched_src_sha1 = self.HashBlocks(
          self.src, self.touched_src_ranges)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image; b) will not be touched by dm-verity. Out of those
    # blocks, we erase the ones that won't be used in this update at the
    # beginning of an update. The rest would be erased at the end. This is to
    # work around the eMMC issue observed on some devices, which may otherwise
    # get starving for clean blocks and thus fail the update. (b/28347095)
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)

    erase_first = new_dontcare.subtract(self.touched_src_ranges)
    if erase_first:
      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))

    erase_last = new_dontcare.subtract(erase_first)
    if erase_last:
      out.append("erase %s\n" % (erase_last.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, self._max_stashed_size, max_allowed,
              self._max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, self._max_stashed_size))
600
  def ReviseStashSize(self):
    """Keep the runtime stash usage within the /cache size limit.

    Simulates the transfer sequence while tracking how many blocks are
    stashed at each point; any command whose explicit or implicit stash
    would exceed the allowed maximum is converted to a "new" command
    (shipping raw data instead of relying on the stash).
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          # NOTE(review): this assumes every stash defined here has a
          # recorded use point (stashes[idx] == (sr, def_cmd, use_cmd));
          # a stash with no use would raise IndexError — confirm upstream
          # guarantees each stash is used.
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up blocks that violates space limit and print total number to
        # screen later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))
673
  def ComputePatches(self, prefix):
    """Classify "diff" transfers and compute their binary patches.

    "diff" transfers whose source equals their target become "move";
    the rest become "bsdiff"/"imgdiff" and their patches are computed on
    self.threads worker threads. Raw data for "new" transfers goes to
    <prefix>.new.dat; the concatenated patches go to <prefix>.patch.dat,
    with each transfer's offset/length recorded on the transfer.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sorted ascending by target size; workers pop() from the end, so
      # the largest (typically slowest) diffs are started first.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches (in patch_num order) and record each
    # transfer's start offset and length within the patch file.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
778
779  def AssertSequenceGood(self):
780    # Simulate the sequences of transfers we will output, and check that:
781    # - we never read a block after writing it, and
782    # - we write every block we care about exactly once.
783
784    # Start with no blocks having been touched yet.
785    touched = array.array("B", "\0" * self.tgt.total_blocks)
786
787    # Imagine processing the transfers in order.
788    for xf in self.transfers:
789      # Check that the input blocks for this transfer haven't yet been touched.
790
791      x = xf.src_ranges
792      if self.version >= 2:
793        for _, sr in xf.use_stash:
794          x = x.subtract(sr)
795
796      for s, e in x:
797        # Source image could be larger. Don't check the blocks that are in the
798        # source image only. Since they are not in 'touched', and won't ever
799        # be touched.
800        for i in range(s, min(e, self.tgt.total_blocks)):
801          assert touched[i] == 0
802
803      # Check that the output blocks for this transfer haven't yet
804      # been touched, and touch all the blocks written by this
805      # transfer.
806      for s, e in xf.tgt_ranges:
807        for i in range(s, e):
808          assert touched[i] == 0
809          touched[i] = 1
810
811    # Check that we've written every target block.
812    for s, e in self.tgt.care_map:
813      for i in range(s, e):
814        assert touched[i] == 1
815
816  def ImproveVertexSequence(self):
817    print("Improving vertex order...")
818
819    # At this point our digraph is acyclic; we reversed any edges that
820    # were backwards in the heuristically-generated sequence.  The
821    # previously-generated order is still acceptable, but we hope to
822    # find a better order that needs less memory for stashed data.
823    # Now we do a topological sort to generate a new vertex order,
824    # using a greedy algorithm to choose which vertex goes next
825    # whenever we have a choice.
826
827    # Make a copy of the edge set; this copy will get destroyed by the
828    # algorithm.
829    for xf in self.transfers:
830      xf.incoming = xf.goes_after.copy()
831      xf.outgoing = xf.goes_before.copy()
832
833    L = []   # the new vertex order
834
835    # S is the set of sources in the remaining graph; we always choose
836    # the one that leaves the least amount of stashed data after it's
837    # executed.
838    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
839         if not u.incoming]
840    heapq.heapify(S)
841
842    while S:
843      _, _, xf = heapq.heappop(S)
844      L.append(xf)
845      for u in xf.outgoing:
846        del u.incoming[xf]
847        if not u.incoming:
848          heapq.heappush(S, (u.NetStashChange(), u.order, u))
849
850    # if this fails then our graph had a cycle.
851    assert len(L) == len(self.transfers)
852
853    self.transfers = L
854    for i, xf in enumerate(L):
855      xf.order = i
856
857  def RemoveBackwardEdges(self):
858    print("Removing backward edges...")
859    in_order = 0
860    out_of_order = 0
861    lost_source = 0
862
863    for xf in self.transfers:
864      lost = 0
865      size = xf.src_ranges.size()
866      for u in xf.goes_before:
867        # xf should go before u
868        if xf.order < u.order:
869          # it does, hurray!
870          in_order += 1
871        else:
872          # it doesn't, boo.  trim the blocks that u writes from xf's
873          # source, so that xf can go after u.
874          out_of_order += 1
875          assert xf.src_ranges.overlaps(u.tgt_ranges)
876          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
877          xf.intact = False
878
879      if xf.style == "diff" and not xf.src_ranges:
880        # nothing left to diff from; treat as new data
881        xf.style = "new"
882
883      lost = size - xf.src_ranges.size()
884      lost_source += lost
885
886    print(("  %d/%d dependencies (%.2f%%) were violated; "
887           "%d source blocks removed.") %
888          (out_of_order, in_order + out_of_order,
889           (out_of_order * 100.0 / (in_order + out_of_order))
890           if (in_order + out_of_order) else 0.0,
891           lost_source))
892
893  def ReverseBackwardEdges(self):
894    print("Reversing backward edges...")
895    in_order = 0
896    out_of_order = 0
897    stashes = 0
898    stash_size = 0
899
900    for xf in self.transfers:
901      for u in xf.goes_before.copy():
902        # xf should go before u
903        if xf.order < u.order:
904          # it does, hurray!
905          in_order += 1
906        else:
907          # it doesn't, boo.  modify u to stash the blocks that it
908          # writes that xf wants to read, and then require u to go
909          # before xf.
910          out_of_order += 1
911
912          overlap = xf.src_ranges.intersect(u.tgt_ranges)
913          assert overlap
914
915          u.stash_before.append((stashes, overlap))
916          xf.use_stash.append((stashes, overlap))
917          stashes += 1
918          stash_size += overlap.size()
919
920          # reverse the edge direction; now xf must go after u
921          del xf.goes_before[u]
922          del u.goes_after[xf]
923          xf.goes_after[u] = None    # value doesn't matter
924          u.goes_before[xf] = None
925
926    print(("  %d/%d dependencies (%.2f%%) were violated; "
927           "%d source blocks stashed.") %
928          (out_of_order, in_order + out_of_order,
929           (out_of_order * 100.0 / (in_order + out_of_order))
930           if (in_order + out_of_order) else 0.0,
931           stash_size))
932
  def FindVertexSequence(self):
    """Choose an order for self.transfers that (heuristically) minimizes
    the total weight of left-pointing (backward) edges, i.e. the source
    blocks that would be lost by removing or reversing those edges later.

    On return, self.transfers is rearranged into the chosen sequence and
    each transfer's 'order' field records its position in it."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.  score = (outgoing weight) - (incoming weight): a high
    # score means the vertex currently looks more like a source than a
    # sink.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    # Max-heap (via HeapItem, defined earlier in this file) used to find
    # the remaining vertex with the best score.
    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    sinks = set(u for u in G if not u.outgoing)
    sources = set(u for u in G if not u.incoming)

    def adjust_score(iu, delta):
      # Heap entries can't be rekeyed in place, so this uses lazy
      # deletion: the old wrapper is invalidated via clear() (which
      # presumably makes it falsy -- see HeapItem) and a fresh wrapper
      # with the updated score is pushed.
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = set()
        for u in sinks:
          if u not in G: continue
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing: new_sinks.add(iu)
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = set()
        for u in sources:
          if u not in G: continue
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming: new_sources.add(iu)
        sources = new_sources

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      while True:
        # Discard stale heap entries (cleared wrappers, or vertices
        # already removed from G) until a live one is found.
        u = heapq.heappop(heap)
        if u and u.item in G:
          u = u.item
          break

      # Treat u as a source: append it to the left-hand side and drop it
      # from the remaining graph.
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming: sources.add(iu)

      # u's incoming edges now point left; their weight is the price
      # paid later by RemoveBackwardEdges/ReverseBackwardEdges.
      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing: sinks.add(iu)

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers
1035
1036  def GenerateDigraph(self):
1037    print("Generating digraph...")
1038
1039    # Each item of source_ranges will be:
1040    #   - None, if that block is not used as a source,
1041    #   - a transfer, if one transfer uses it as a source, or
1042    #   - a set of transfers.
1043    source_ranges = []
1044    for b in self.transfers:
1045      for s, e in b.src_ranges:
1046        if e > len(source_ranges):
1047          source_ranges.extend([None] * (e-len(source_ranges)))
1048        for i in range(s, e):
1049          if source_ranges[i] is None:
1050            source_ranges[i] = b
1051          else:
1052            if not isinstance(source_ranges[i], set):
1053              source_ranges[i] = set([source_ranges[i]])
1054            source_ranges[i].add(b)
1055
1056    for a in self.transfers:
1057      intersections = set()
1058      for s, e in a.tgt_ranges:
1059        for i in range(s, e):
1060          if i >= len(source_ranges): break
1061          b = source_ranges[i]
1062          if b is not None:
1063            if isinstance(b, set):
1064              intersections.update(b)
1065            else:
1066              intersections.add(b)
1067
1068      for b in intersections:
1069        if a is b: continue
1070
1071        # If the blocks written by A are read by B, then B needs to go before A.
1072        i = a.tgt_ranges.intersect(b.src_ranges)
1073        if i:
1074          if b.src_name == "__ZERO":
1075            # the cost of removing source blocks for the __ZERO domain
1076            # is (nearly) zero.
1077            size = 0
1078          else:
1079            size = i.size()
1080          b.goes_before[a] = size
1081          a.goes_after[b] = size
1082
1083  def FindTransfers(self):
1084    """Parse the file_map to generate all the transfers."""
1085
1086    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1087                    split=False):
1088      """Wrapper function for adding a Transfer().
1089
1090      For BBOTA v3, we need to stash source blocks for resumable feature.
1091      However, with the growth of file size and the shrink of the cache
1092      partition source blocks are too large to be stashed. If a file occupies
1093      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
1094      into smaller pieces by getting multiple Transfer()s.
1095
1096      The downside is that after splitting, we may increase the package size
1097      since the split pieces don't align well. According to our experiments,
1098      1/8 of the cache size as the per-piece limit appears to be optimal.
1099      Compared to the fixed 1024-block limit, it reduces the overall package
1100      size by 30% volantis, and 20% for angler and bullhead."""
1101
1102      # We care about diff transfers only.
1103      if style != "diff" or not split:
1104        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1105        return
1106
1107      pieces = 0
1108      cache_size = common.OPTIONS.cache_size
1109      split_threshold = 0.125
1110      max_blocks_per_transfer = int(cache_size * split_threshold /
1111                                    self.tgt.blocksize)
1112
1113      # Change nothing for small files.
1114      if (tgt_ranges.size() <= max_blocks_per_transfer and
1115          src_ranges.size() <= max_blocks_per_transfer):
1116        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1117        return
1118
1119      while (tgt_ranges.size() > max_blocks_per_transfer and
1120             src_ranges.size() > max_blocks_per_transfer):
1121        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1122        src_split_name = "%s-%d" % (src_name, pieces)
1123        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1124        src_first = src_ranges.first(max_blocks_per_transfer)
1125
1126        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
1127                 by_id)
1128
1129        tgt_ranges = tgt_ranges.subtract(tgt_first)
1130        src_ranges = src_ranges.subtract(src_first)
1131        pieces += 1
1132
1133      # Handle remaining blocks.
1134      if tgt_ranges.size() or src_ranges.size():
1135        # Must be both non-empty.
1136        assert tgt_ranges.size() and src_ranges.size()
1137        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1138        src_split_name = "%s-%d" % (src_name, pieces)
1139        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
1140                 by_id)
1141
1142    empty = RangeSet()
1143    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
1144      if tgt_fn == "__ZERO":
1145        # the special "__ZERO" domain is all the blocks not contained
1146        # in any file and that are filled with zeros.  We have a
1147        # special transfer style for zero blocks.
1148        src_ranges = self.src.file_map.get("__ZERO", empty)
1149        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1150                    "zero", self.transfers)
1151        continue
1152
1153      elif tgt_fn == "__COPY":
1154        # "__COPY" domain includes all the blocks not contained in any
1155        # file and that need to be copied unconditionally to the target.
1156        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1157        continue
1158
1159      elif tgt_fn in self.src.file_map:
1160        # Look for an exact pathname match in the source.
1161        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1162                    "diff", self.transfers, self.version >= 3)
1163        continue
1164
1165      b = os.path.basename(tgt_fn)
1166      if b in self.src_basenames:
1167        # Look for an exact basename match in the source.
1168        src_fn = self.src_basenames[b]
1169        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1170                    "diff", self.transfers, self.version >= 3)
1171        continue
1172
1173      b = re.sub("[0-9]+", "#", b)
1174      if b in self.src_numpatterns:
1175        # Look for a 'number pattern' match (a basename match after
1176        # all runs of digits are replaced by "#").  (This is useful
1177        # for .so files that contain version numbers in the filename
1178        # that get bumped.)
1179        src_fn = self.src_numpatterns[b]
1180        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1181                    "diff", self.transfers, self.version >= 3)
1182        continue
1183
1184      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1185
1186  def AbbreviateSourceNames(self):
1187    for k in self.src.file_map.keys():
1188      b = os.path.basename(k)
1189      self.src_basenames[b] = k
1190      b = re.sub("[0-9]+", "#", b)
1191      self.src_numpatterns[b] = k
1192
1193  @staticmethod
1194  def AssertPartition(total, seq):
1195    """Assert that all the RangeSets in 'seq' form a partition of the
1196    'total' RangeSet (ie, they are nonintersecting and their union
1197    equals 'total')."""
1198
1199    so_far = RangeSet()
1200    for i in seq:
1201      assert not so_far.overlaps(i)
1202      so_far = so_far.union(i)
1203    assert so_far == total
1204