# blockimgdiff.py revision fc44a515d46e6f4d5eaa0d32659b1cf3b9492305
from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import tempfile
import threading

from rangelib import *
16
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms src into tgt.

  Args:
    src: iterable of strings; concatenated, they form the source data.
    tgt: iterable of strings; concatenated, they form the target data.
    imgdiff: if True, run "imgdiff -z" (appropriate for zip-format
        data); otherwise run "bsdiff".

  Returns:
    The contents of the generated patch file, as a string.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tools create the patch file themselves; remove the
    # placeholder that mkstemp created.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Close the /dev/null handle when done instead of leaking it
      # (the original opened it without ever closing it).
      with open("/dev/null", "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each temp file independently so a failure on one does not
    # leave the others behind.
    for fn in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(fn)
      except OSError:
        pass
54
class EmptyImage(object):
  """An image containing no data at all.

  Used as the source image when none is supplied, so that every
  target block is emitted as new data.
  """
  total_blocks = 0
  blocksize = 4096
  file_map = {}
  care_map = RangeSet()

  def ReadRangeSet(self, ranges):
    # There is nothing to read; every request yields no data.
    return ()
63
class Transfer(object):
  """One update command: produce tgt_ranges from src_ranges via 'style'."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # A transfer is "intact" when both range sets keep their blocks in
    # increasing order; only then does concatenating the blocks
    # reproduce the original file layout (required for imgdiff).
    tgt_mono = getattr(tgt_ranges, "monotonic", False)
    self.intact = tgt_mono and getattr(src_ranges, "monotonic", False)
    # Ordering constraints relative to other transfers; populated later
    # by BlockImageDiff.GenerateDigraph().
    self.goes_before = {}
    self.goes_after = {}
    # Register this transfer in the shared list and record its index.
    self.id = len(by_id)
    by_id.append(self)

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
83
# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes a block-level difference between two partition images.

  Given a target image and an optional source image (see the module
  comment for the image interface), produces the transfer list, patch
  data, and new-data files that the on-device updater consumes to turn
  the source partition into the target partition.
  """

  def __init__(self, tgt, src=None, threads=None):
    """Set up the diff.

    Args:
      tgt: the target image object.
      src: the source image object, or None to produce a full image
          (every block emitted as new data).
      threads: number of worker threads used to compute patches;
          defaults to half the CPU count, minimum 1.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline, writing output files named prefix.*."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write the transfer list to prefix + ".transfer.list"."""
    out = []

    out.append("1\n")   # format version number
    total = 0           # total blocks written, inserted as line 2 below
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source and target ranges are identical is a
        # no-op; emit nothing for it.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero the blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # NOTE: was Python-2-only "raise ValueError, ..." syntax; the
        # call form below behaves identically and is also valid Python 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Classify "diff" transfers and compute their patches.

    Writes prefix + ".new.dat" (all "new" data) and prefix +
    ".patch.dat" (concatenated bsdiff/imgdiff patches), and rewrites
    each transfer's style to move/bsdiff/imgdiff as appropriate.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort by target size so workers pop the largest jobs first.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        # Pull jobs from diff_q until it is empty; 'lock' protects the
        # queue, the results list, and stdout.
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate the patches in order, recording each one's offset and
    # length on its transfer for use in the transfer list.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Trim source blocks so the chosen order violates no dependency.

    For every "goes before" edge that the sequence from
    FindVertexSequence() failed to honor, remove the conflicting
    blocks from the reader's source set (the data will come from the
    patch instead), so the sequence becomes safe to execute.
    """
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order the transfers to minimize the weight of violated edges.

    Sets xf.order on every transfer and rewrites self.transfers into
    the chosen sequence.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Record pairwise ordering constraints between all transfers."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Build the initial transfer list, one Transfer per target domain."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, None, tgt_ranges, src_ranges, "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the data will be shipped verbatim.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source filenames by basename and by digit-collapsed basename."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total