blockimgdiff.py revision c50f8359e67e312867b066141f97ab0eb7dff137
1from __future__ import print_function 2 3from collections import deque, OrderedDict 4from hashlib import sha1 5import itertools 6import multiprocessing 7import os 8import pprint 9import re 10import subprocess 11import sys 12import threading 13import tempfile 14 15from rangelib import * 16 17__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 18 19def compute_patch(src, tgt, imgdiff=False): 20 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 21 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 22 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 23 os.close(patchfd) 24 25 try: 26 with os.fdopen(srcfd, "wb") as f_src: 27 for p in src: 28 f_src.write(p) 29 30 with os.fdopen(tgtfd, "wb") as f_tgt: 31 for p in tgt: 32 f_tgt.write(p) 33 try: 34 os.unlink(patchfile) 35 except OSError: 36 pass 37 if imgdiff: 38 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 39 stdout=open("/dev/null", "a"), 40 stderr=subprocess.STDOUT) 41 else: 42 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 43 44 if p: 45 raise ValueError("diff failed: " + str(p)) 46 47 with open(patchfile, "rb") as f: 48 return f.read() 49 finally: 50 try: 51 os.unlink(srcfile) 52 os.unlink(tgtfile) 53 os.unlink(patchfile) 54 except OSError: 55 pass 56 57class EmptyImage(object): 58 """A zero-length image.""" 59 blocksize = 4096 60 care_map = RangeSet() 61 total_blocks = 0 62 file_map = {} 63 def ReadRangeSet(self, ranges): 64 return () 65 def TotalSha1(self): 66 return sha1().hexdigest() 67 68 69class DataImage(object): 70 """An image wrapped around a single string of data.""" 71 72 def __init__(self, data, trim=False, pad=False): 73 self.data = data 74 self.blocksize = 4096 75 76 assert not (trim and pad) 77 78 partial = len(self.data) % self.blocksize 79 if partial > 0: 80 if trim: 81 self.data = self.data[:-partial] 82 elif pad: 83 self.data += '\0' * (self.blocksize - partial) 84 else: 85 raise ValueError(("data for DataImage must be multiple of %d bytes " 86 "unless 
trim or pad is specified") % 87 (self.blocksize,)) 88 89 assert len(self.data) % self.blocksize == 0 90 91 self.total_blocks = len(self.data) / self.blocksize 92 self.care_map = RangeSet(data=(0, self.total_blocks)) 93 94 zero_blocks = [] 95 nonzero_blocks = [] 96 reference = '\0' * self.blocksize 97 98 for i in range(self.total_blocks): 99 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 100 if d == reference: 101 zero_blocks.append(i) 102 zero_blocks.append(i+1) 103 else: 104 nonzero_blocks.append(i) 105 nonzero_blocks.append(i+1) 106 107 self.file_map = {"__ZERO": RangeSet(zero_blocks), 108 "__NONZERO": RangeSet(nonzero_blocks)} 109 110 def ReadRangeSet(self, ranges): 111 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 112 113 def TotalSha1(self): 114 if not hasattr(self, "sha1"): 115 self.sha1 = sha1(self.data).hexdigest() 116 return self.sha1 117 118 119class Transfer(object): 120 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 121 self.tgt_name = tgt_name 122 self.src_name = src_name 123 self.tgt_ranges = tgt_ranges 124 self.src_ranges = src_ranges 125 self.style = style 126 self.intact = (getattr(tgt_ranges, "monotonic", False) and 127 getattr(src_ranges, "monotonic", False)) 128 self.goes_before = {} 129 self.goes_after = {} 130 131 self.id = len(by_id) 132 by_id.append(self) 133 134 def __str__(self): 135 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 136 " to " + str(self.tgt_ranges) + ">") 137 138 139# BlockImageDiff works on two image objects. An image object is 140# anything that provides the following attributes: 141# 142# blocksize: the size in bytes of a block, currently must be 4096. 143# 144# total_blocks: the total size of the partition/image, in blocks. 145# 146# care_map: a RangeSet containing which blocks (in the range [0, 147# total_blocks) we actually care about; i.e. which blocks contain 148# data. 
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map)
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Compute a block-level update script transforming 'src' into 'tgt'.

  Call Compute(prefix) to produce <prefix>.transfer.list,
  <prefix>.new.dat, and <prefix>.patch.dat.
  """

  def __init__(self, tgt, src=None, threads=None):
    if threads is None:
      # Default to half the CPUs (patch generation is memory-hungry),
      # but always at least one.
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full pipeline and write the output files under 'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write <prefix>.transfer.list: one command per transfer, preceded
    by the format version and total block count, followed by an erase
    of whatever blocks can safely be discarded."""
    out = []

    out.append("1\n")   # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op on the
        # device; emit no command but still count the blocks.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # NOTE: was Python-2-only "raise ValueError, msg" syntax;
        # the call form below is valid in both Python 2 and 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Turn each "diff" transfer into "move" (identical data) or a
    bsdiff/imgdiff patch; write <prefix>.new.dat and <prefix>.patch.dat."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so workers (which pop from the end) take the largest
      # targets first.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches; record each one's offset/length on its
    # transfer so WriteTransfers can reference them.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """For every dependency the chosen order violates, drop the
    conflicting blocks from the reader's source set so the order
    becomes safe (the dropped blocks will be re-diffed or re-sent)."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Record, for every pair of transfers, who must run before whom
    (B before A whenever A overwrites blocks B still needs to read)."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per target domain, matching each to the best
    available source (exact path, basename, or digit-pattern match)."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No usable source; send the target data in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Build lookup tables mapping source basenames and digit-collapsed
    basenames ("lib12.so" -> "lib#.so") back to full source paths."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total