blockimgdiff.py revision fc44a515d46e6f4d5eaa0d32659b1cf3b9492305
1from __future__ import print_function 2 3from collections import deque, OrderedDict 4from hashlib import sha1 5import itertools 6import multiprocessing 7import os 8import pprint 9import re 10import subprocess 11import sys 12import threading 13import tempfile 14 15from rangelib import * 16 17def compute_patch(src, tgt, imgdiff=False): 18 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 19 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 20 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 21 os.close(patchfd) 22 23 try: 24 with os.fdopen(srcfd, "wb") as f_src: 25 for p in src: 26 f_src.write(p) 27 28 with os.fdopen(tgtfd, "wb") as f_tgt: 29 for p in tgt: 30 f_tgt.write(p) 31 try: 32 os.unlink(patchfile) 33 except OSError: 34 pass 35 if imgdiff: 36 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 37 stdout=open("/dev/null", "a"), 38 stderr=subprocess.STDOUT) 39 else: 40 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 41 42 if p: 43 raise ValueError("diff failed: " + str(p)) 44 45 with open(patchfile, "rb") as f: 46 return f.read() 47 finally: 48 try: 49 os.unlink(srcfile) 50 os.unlink(tgtfile) 51 os.unlink(patchfile) 52 except OSError: 53 pass 54 55class EmptyImage(object): 56 """A zero-length image.""" 57 blocksize = 4096 58 care_map = RangeSet() 59 total_blocks = 0 60 file_map = {} 61 def ReadRangeSet(self, ranges): 62 return () 63 64class Transfer(object): 65 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 66 self.tgt_name = tgt_name 67 self.src_name = src_name 68 self.tgt_ranges = tgt_ranges 69 self.src_ranges = src_ranges 70 self.style = style 71 self.intact = (getattr(tgt_ranges, "monotonic", False) and 72 getattr(src_ranges, "monotonic", False)) 73 self.goes_before = {} 74 self.goes_after = {} 75 76 self.id = len(by_id) 77 by_id.append(self) 78 79 def __str__(self): 80 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 81 " to " + str(self.tgt_ranges) + ">") 82 83 84# 
# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes the difference between two block images.

  Call Compute() to generate the transfer list, new-data file, and
  patch file for converting 'src' into 'tgt'.
  """

  def __init__(self, tgt, src=None, threads=None):
    if threads is None:
      # Default to half the CPUs (patch computation is memory-hungry),
      # but always at least one thread.
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full pipeline, writing output files named prefix.*."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write the transfer list (prefix + ".transfer.list")."""
    out = []

    out.append("1\n")  # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # Command formats emitted below:
      #
      #   zero [rangeset]
      #   new [rangeset]
      #   bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      #   imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      #   move [src rangeset] [tgt rangeset]
      #   erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op; emit
        # nothing and don't count it.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero the blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Fix: was Python-2-only "raise ValueError, ..." syntax.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    # Line 2 of the output is the total number of blocks written.
    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Compute all diffs, writing prefix+".new.dat" and prefix+".patch.dat".

    "diff"-style transfers are refined here into "move" (identical
    data), "bsdiff", or "imgdiff"; patch computation is farmed out to
    self.threads worker threads.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently,
          # eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so the largest targets are popped (and reported) last.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()

      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into one file, recording each patch's
    # offset/length on its transfer for WriteTransfers().
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Trim source blocks so every transfer can run in sequence order."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order the transfers to minimize the weight of violated edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for every domain in the target file_map."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, None, tgt_ranges, src_ranges, "zero",
                 self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the target data is sent as-is.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and by digit-collapsed basename."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total