blockimgdiff.py revision c50f8359e67e312867b066141f97ab0eb7dff137
1from __future__ import print_function
2
3from collections import deque, OrderedDict
4from hashlib import sha1
5import itertools
6import multiprocessing
7import os
8import pprint
9import re
10import subprocess
11import sys
12import threading
13import tempfile
14
15from rangelib import *
16
17__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
18
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch transforming src data into tgt data.

  src, tgt: iterables of strings; concatenating the elements of each
      yields the full source/target data.
  imgdiff: if True, run "imgdiff -z" (produces much smaller patches for
      zip-format data); otherwise run "bsdiff".

  Returns the patch as a string of bytes.  Raises ValueError if the
  external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tools write the patch file themselves; remove the empty
    # placeholder created by mkstemp first.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # imgdiff is chatty on stdout/stderr; discard its output.  Use a
      # with-statement so the null-device handle is closed (the old code
      # leaked one open file per call).
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each temp file independently: with a single try block, one
    # failed unlink would leak the remaining files.
    for fn in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(fn)
      except OSError:
        pass
56
class EmptyImage(object):
  """A zero-length image: no blocks, no files, nothing to read."""
  total_blocks = 0
  blocksize = 4096
  file_map = {}
  care_map = RangeSet()

  def ReadRangeSet(self, ranges):
    # There is no data, so every read yields the empty tuple.
    return ()

  def TotalSha1(self):
    # SHA-1 of zero bytes of content.
    return sha1().hexdigest()
67
68
class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' as an image of 4096-byte blocks.

    If len(data) is not a multiple of the block size, exactly one of
    trim (drop the partial final block) or pad (zero-fill it to a full
    block) must be given; otherwise ValueError is raised.
    """
    self.data = data
    self.blocksize = 4096

    # trim and pad are mutually exclusive ways to fix up a partial
    # final block.
    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Use floor division so total_blocks is an int on both Python 2 and
    # Python 3 (plain "/" yields a float on Python 3, which would break
    # RangeSet construction and range() below).
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Partition the blocks into the all-zero ones and the rest; each
    # block contributes a [i, i+1) range pair to its RangeSet.
    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Return the data of the given RangeSet as a list of slices, one
    per (start, end) block range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    """Return the hex SHA-1 of the whole image (cached after the first
    call)."""
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
117
118
class Transfer(object):
  """One entry in the transfer list: an operation ('style') that writes
  tgt_ranges, possibly reading src_ranges, registered in 'by_id'."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style

    # "Intact" means both range sets store their data with blocks in
    # increasing order (required for imgdiff to be usable later).
    tgt_mono = getattr(tgt_ranges, "monotonic", False)
    src_mono = getattr(src_ranges, "monotonic", False)
    self.intact = tgt_mono and src_mono

    # Ordering-dependency edges, filled in by GenerateDigraph.
    self.goes_before = {}
    self.goes_after = {}

    # Sequential id: position of this transfer in the by_id registry.
    self.id = len(by_id)
    by_id.append(self)

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
137
138
139# BlockImageDiff works on two image objects.  An image object is
140# anything that provides the following attributes:
141#
142#    blocksize: the size in bytes of a block, currently must be 4096.
143#
144#    total_blocks: the total size of the partition/image, in blocks.
145#
146#    care_map: a RangeSet containing which blocks (in the range [0,
147#      total_blocks) we actually care about; i.e. which blocks contain
148#      data.
149#
150#    file_map: a dict that partitions the blocks contained in care_map
151#      into smaller domains that are useful for doing diffs on.
152#      (Typically a domain is a file, and the key in file_map is the
153#      pathname.)
154#
155#    ReadRangeSet(): a function that takes a RangeSet and returns the
156#      data contained in the image blocks of that RangeSet.  The data
157#      is returned as a list or tuple of strings; concatenating the
158#      elements together should produce the requested data.
159#      Implementations are free to break up the data into list/tuple
160#      elements in any way that is convenient.
161#
162#    TotalSha1(): a function that returns (as a hex string) the SHA-1
163#      hash of all the data in the image (ie, all the blocks in the
164#      care_map)
165#
166# When creating a BlockImageDiff, the src image may be None, in which
167# case the list of transfers produced will never read from the
168# original image.
169
class BlockImageDiff(object):
  """Computes a block-level diff between two images and writes the
  result as a transfer list plus patch/new-data files.

  See the module comment above for the interface the tgt and src image
  objects must provide.  If src is None, an EmptyImage is substituted
  and the resulting transfer list never reads from the source (a full,
  non-incremental update).
  """

  def __init__(self, tgt, src=None, threads=None):
    # Default worker count: half the cores, but at least one thread.
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write <prefix>.transfer.list,
    <prefix>.new.dat and <prefix>.patch.dat."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Serialize self.transfers to <prefix>.transfer.list.

    Output format: line 1 is the version number, line 2 the total
    number of blocks written, followed by one command per line.
    """
    out = []

    out.append("1\n")   # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # Command vocabulary emitted below:
      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op; omit it.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero the blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError, "unknown transfer style '%s'\n" % (xf.style,)

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Decide the final style of each "diff" transfer (move/bsdiff/
    imgdiff), compute the patches (possibly multithreaded), and write
    <prefix>.new.dat and <prefix>.patch.dat.  Sets patch_start and
    patch_len on each patched transfer."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          # Raw target data for "new" transfers goes into new.dat.
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort by target size so the largest (slowest) diffs start first,
      # keeping the worker pool busy (items are popped from the end).
      diff_q.sort()

      patches = [None] * patch_num

      # The lock guards diff_q and patches; compute_patch itself runs
      # outside the lock so workers diff in parallel.
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into patch.dat, recording each one's
    # offset and length on its transfer for WriteTransfers.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Make the transfer order (set by FindVertexSequence) safe: for
    every dependency the chosen order violates, shrink the violating
    transfer's source set so it no longer reads blocks an earlier
    transfer writes.  Transfers left with no source become "new"."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order self.transfers to (heuristically) minimize the total
    weight of ordering dependencies that point "backward"; sets the
    'order' field on every transfer."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Populate goes_before/goes_after on every transfer: B must run
    before A whenever A overwrites blocks B still needs to read.  Edge
    weight is the number of source blocks at stake."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Build the initial self.transfers list: one Transfer per target
    domain, matched against a source domain by the name-matching rules
    described in Compute()."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the target data must be sent in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index the source file map by basename and by digit-collapsed
    basename, for the fallback matching in FindTransfers."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
609