# blockimgdiff.py revision 28f6f9c3deb4d23fd627d15631d682a5cfa989fc
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import common 20import heapq 21import itertools 22import multiprocessing 23import os 24import re 25import subprocess 26import threading 27import tempfile 28 29from rangelib import RangeSet 30 31 32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 33 34 35def compute_patch(src, tgt, imgdiff=False): 36 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 37 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 38 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 39 os.close(patchfd) 40 41 try: 42 with os.fdopen(srcfd, "wb") as f_src: 43 for p in src: 44 f_src.write(p) 45 46 with os.fdopen(tgtfd, "wb") as f_tgt: 47 for p in tgt: 48 f_tgt.write(p) 49 try: 50 os.unlink(patchfile) 51 except OSError: 52 pass 53 if imgdiff: 54 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 55 stdout=open("/dev/null", "a"), 56 stderr=subprocess.STDOUT) 57 else: 58 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 59 60 if p: 61 raise ValueError("diff failed: " + str(p)) 62 63 with open(patchfile, "rb") as f: 64 return f.read() 65 finally: 66 try: 67 os.unlink(srcfile) 68 os.unlink(tgtfile) 69 os.unlink(patchfile) 70 except OSError: 71 pass 72 73 74class Image(object): 75 def ReadRangeSet(self, ranges): 76 raise NotImplementedError 77 78 
def TotalSha1(self, include_clobbered_blocks=False): 79 raise NotImplementedError 80 81 82class EmptyImage(Image): 83 """A zero-length image.""" 84 blocksize = 4096 85 care_map = RangeSet() 86 clobbered_blocks = RangeSet() 87 extended = RangeSet() 88 total_blocks = 0 89 file_map = {} 90 def ReadRangeSet(self, ranges): 91 return () 92 def TotalSha1(self, include_clobbered_blocks=False): 93 # EmptyImage always carries empty clobbered_blocks, so 94 # include_clobbered_blocks can be ignored. 95 assert self.clobbered_blocks.size() == 0 96 return sha1().hexdigest() 97 98 99class DataImage(Image): 100 """An image wrapped around a single string of data.""" 101 102 def __init__(self, data, trim=False, pad=False): 103 self.data = data 104 self.blocksize = 4096 105 106 assert not (trim and pad) 107 108 partial = len(self.data) % self.blocksize 109 padded = False 110 if partial > 0: 111 if trim: 112 self.data = self.data[:-partial] 113 elif pad: 114 self.data += '\0' * (self.blocksize - partial) 115 padded = True 116 else: 117 raise ValueError(("data for DataImage must be multiple of %d bytes " 118 "unless trim or pad is specified") % 119 (self.blocksize,)) 120 121 assert len(self.data) % self.blocksize == 0 122 123 self.total_blocks = len(self.data) / self.blocksize 124 self.care_map = RangeSet(data=(0, self.total_blocks)) 125 # When the last block is padded, we always write the whole block even for 126 # incremental OTAs. Because otherwise the last block may get skipped if 127 # unchanged for an incremental, but would fail the post-install 128 # verification if it has non-zero contents in the padding bytes. 
129 # Bug: 23828506 130 if padded: 131 self.clobbered_blocks = RangeSet( 132 data=(self.total_blocks-1, self.total_blocks)) 133 else: 134 self.clobbered_blocks = RangeSet() 135 self.extended = RangeSet() 136 137 zero_blocks = [] 138 nonzero_blocks = [] 139 reference = '\0' * self.blocksize 140 141 for i in range(self.total_blocks-1 if padded else self.total_blocks): 142 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 143 if d == reference: 144 zero_blocks.append(i) 145 zero_blocks.append(i+1) 146 else: 147 nonzero_blocks.append(i) 148 nonzero_blocks.append(i+1) 149 150 self.file_map = {"__ZERO": RangeSet(zero_blocks), 151 "__NONZERO": RangeSet(nonzero_blocks)} 152 153 if self.clobbered_blocks: 154 self.file_map["__COPY"] = self.clobbered_blocks 155 156 def ReadRangeSet(self, ranges): 157 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 158 159 def TotalSha1(self, include_clobbered_blocks=False): 160 if not include_clobbered_blocks: 161 ranges = self.care_map.subtract(self.clobbered_blocks) 162 return sha1(self.ReadRangeSet(ranges)).hexdigest() 163 else: 164 return sha1(self.data).hexdigest() 165 166 167class Transfer(object): 168 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 169 self.tgt_name = tgt_name 170 self.src_name = src_name 171 self.tgt_ranges = tgt_ranges 172 self.src_ranges = src_ranges 173 self.style = style 174 self.intact = (getattr(tgt_ranges, "monotonic", False) and 175 getattr(src_ranges, "monotonic", False)) 176 177 # We use OrderedDict rather than dict so that the output is repeatable; 178 # otherwise it would depend on the hash values of the Transfer objects. 
179 self.goes_before = OrderedDict() 180 self.goes_after = OrderedDict() 181 182 self.stash_before = [] 183 self.use_stash = [] 184 185 self.id = len(by_id) 186 by_id.append(self) 187 188 def NetStashChange(self): 189 return (sum(sr.size() for (_, sr) in self.stash_before) - 190 sum(sr.size() for (_, sr) in self.use_stash)) 191 192 def ConvertToNew(self): 193 assert self.style != "new" 194 self.use_stash = [] 195 self.style = "new" 196 self.src_ranges = RangeSet() 197 198 def __str__(self): 199 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 200 " to " + str(self.tgt_ranges) + ">") 201 202 203# BlockImageDiff works on two image objects. An image object is 204# anything that provides the following attributes: 205# 206# blocksize: the size in bytes of a block, currently must be 4096. 207# 208# total_blocks: the total size of the partition/image, in blocks. 209# 210# care_map: a RangeSet containing which blocks (in the range [0, 211# total_blocks) we actually care about; i.e. which blocks contain 212# data. 213# 214# file_map: a dict that partitions the blocks contained in care_map 215# into smaller domains that are useful for doing diffs on. 216# (Typically a domain is a file, and the key in file_map is the 217# pathname.) 218# 219# clobbered_blocks: a RangeSet containing which blocks contain data 220# but may be altered by the FS. They need to be excluded when 221# verifying the partition integrity. 222# 223# ReadRangeSet(): a function that takes a RangeSet and returns the 224# data contained in the image blocks of that RangeSet. The data 225# is returned as a list or tuple of strings; concatenating the 226# elements together should produce the requested data. 227# Implementations are free to break up the data into list/tuple 228# elements in any way that is convenient. 
229# 230# TotalSha1(): a function that returns (as a hex string) the SHA-1 231# hash of all the data in the image (ie, all the blocks in the 232# care_map minus clobbered_blocks, or including the clobbered 233# blocks if include_clobbered_blocks is True). 234# 235# When creating a BlockImageDiff, the src image may be None, in which 236# case the list of transfers produced will never read from the 237# original image. 238 239class BlockImageDiff(object): 240 def __init__(self, tgt, src=None, threads=None, version=3): 241 if threads is None: 242 threads = multiprocessing.cpu_count() // 2 243 if threads == 0: 244 threads = 1 245 self.threads = threads 246 self.version = version 247 self.transfers = [] 248 self.src_basenames = {} 249 self.src_numpatterns = {} 250 251 assert version in (1, 2, 3) 252 253 self.tgt = tgt 254 if src is None: 255 src = EmptyImage() 256 self.src = src 257 258 # The updater code that installs the patch always uses 4k blocks. 259 assert tgt.blocksize == 4096 260 assert src.blocksize == 4096 261 262 # The range sets in each filemap should comprise a partition of 263 # the care map. 264 self.AssertPartition(src.care_map, src.file_map.values()) 265 self.AssertPartition(tgt.care_map, tgt.file_map.values()) 266 267 def Compute(self, prefix): 268 # When looking for a source file to use as the diff input for a 269 # target file, we try: 270 # 1) an exact path match if available, otherwise 271 # 2) a exact basename match if available, otherwise 272 # 3) a basename match after all runs of digits are replaced by 273 # "#" if available, otherwise 274 # 4) we have no source for this target. 275 self.AbbreviateSourceNames() 276 self.FindTransfers() 277 278 # Find the ordering dependencies among transfers (this is O(n^2) 279 # in the number of transfers). 280 self.GenerateDigraph() 281 # Find a sequence of transfers that satisfies as many ordering 282 # dependencies as possible (heuristically). 
283 self.FindVertexSequence() 284 # Fix up the ordering dependencies that the sequence didn't 285 # satisfy. 286 if self.version == 1: 287 self.RemoveBackwardEdges() 288 else: 289 self.ReverseBackwardEdges() 290 self.ImproveVertexSequence() 291 292 # Ensure the runtime stash size is under the limit. 293 if self.version >= 2 and common.OPTIONS.cache_size is not None: 294 self.ReviseStashSize() 295 296 # Double-check our work. 297 self.AssertSequenceGood() 298 299 self.ComputePatches(prefix) 300 self.WriteTransfers(prefix) 301 302 def HashBlocks(self, source, ranges): # pylint: disable=no-self-use 303 data = source.ReadRangeSet(ranges) 304 ctx = sha1() 305 306 for p in data: 307 ctx.update(p) 308 309 return ctx.hexdigest() 310 311 def WriteTransfers(self, prefix): 312 out = [] 313 314 total = 0 315 316 stashes = {} 317 stashed_blocks = 0 318 max_stashed_blocks = 0 319 320 free_stash_ids = [] 321 next_stash_id = 0 322 323 for xf in self.transfers: 324 325 if self.version < 2: 326 assert not xf.stash_before 327 assert not xf.use_stash 328 329 for s, sr in xf.stash_before: 330 assert s not in stashes 331 if free_stash_ids: 332 sid = heapq.heappop(free_stash_ids) 333 else: 334 sid = next_stash_id 335 next_stash_id += 1 336 stashes[s] = sid 337 stashed_blocks += sr.size() 338 if self.version == 2: 339 out.append("stash %d %s\n" % (sid, sr.to_string_raw())) 340 else: 341 sh = self.HashBlocks(self.src, sr) 342 if sh in stashes: 343 stashes[sh] += 1 344 else: 345 stashes[sh] = 1 346 out.append("stash %s %s\n" % (sh, sr.to_string_raw())) 347 348 if stashed_blocks > max_stashed_blocks: 349 max_stashed_blocks = stashed_blocks 350 351 free_string = [] 352 353 if self.version == 1: 354 src_str = xf.src_ranges.to_string_raw() 355 elif self.version >= 2: 356 357 # <# blocks> <src ranges> 358 # OR 359 # <# blocks> <src ranges> <src locs> <stash refs...> 360 # OR 361 # <# blocks> - <stash refs...> 362 363 size = xf.src_ranges.size() 364 src_str = [str(size)] 365 366 
unstashed_src_ranges = xf.src_ranges 367 mapped_stashes = [] 368 for s, sr in xf.use_stash: 369 sid = stashes.pop(s) 370 stashed_blocks -= sr.size() 371 unstashed_src_ranges = unstashed_src_ranges.subtract(sr) 372 sh = self.HashBlocks(self.src, sr) 373 sr = xf.src_ranges.map_within(sr) 374 mapped_stashes.append(sr) 375 if self.version == 2: 376 src_str.append("%d:%s" % (sid, sr.to_string_raw())) 377 else: 378 assert sh in stashes 379 src_str.append("%s:%s" % (sh, sr.to_string_raw())) 380 stashes[sh] -= 1 381 if stashes[sh] == 0: 382 free_string.append("free %s\n" % (sh)) 383 stashes.pop(sh) 384 heapq.heappush(free_stash_ids, sid) 385 386 if unstashed_src_ranges: 387 src_str.insert(1, unstashed_src_ranges.to_string_raw()) 388 if xf.use_stash: 389 mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges) 390 src_str.insert(2, mapped_unstashed.to_string_raw()) 391 mapped_stashes.append(mapped_unstashed) 392 self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes) 393 else: 394 src_str.insert(1, "-") 395 self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes) 396 397 src_str = " ".join(src_str) 398 399 # all versions: 400 # zero <rangeset> 401 # new <rangeset> 402 # erase <rangeset> 403 # 404 # version 1: 405 # bsdiff patchstart patchlen <src rangeset> <tgt rangeset> 406 # imgdiff patchstart patchlen <src rangeset> <tgt rangeset> 407 # move <src rangeset> <tgt rangeset> 408 # 409 # version 2: 410 # bsdiff patchstart patchlen <tgt rangeset> <src_str> 411 # imgdiff patchstart patchlen <tgt rangeset> <src_str> 412 # move <tgt rangeset> <src_str> 413 # 414 # version 3: 415 # bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str> 416 # imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str> 417 # move hash <tgt rangeset> <src_str> 418 419 tgt_size = xf.tgt_ranges.size() 420 421 if xf.style == "new": 422 assert xf.tgt_ranges 423 out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw())) 424 total += tgt_size 425 elif 
xf.style == "move": 426 assert xf.tgt_ranges 427 assert xf.src_ranges.size() == tgt_size 428 if xf.src_ranges != xf.tgt_ranges: 429 if self.version == 1: 430 out.append("%s %s %s\n" % ( 431 xf.style, 432 xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw())) 433 elif self.version == 2: 434 out.append("%s %s %s\n" % ( 435 xf.style, 436 xf.tgt_ranges.to_string_raw(), src_str)) 437 elif self.version >= 3: 438 # take into account automatic stashing of overlapping blocks 439 if xf.src_ranges.overlaps(xf.tgt_ranges): 440 temp_stash_usage = stashed_blocks + xf.src_ranges.size() 441 if temp_stash_usage > max_stashed_blocks: 442 max_stashed_blocks = temp_stash_usage 443 444 out.append("%s %s %s %s\n" % ( 445 xf.style, 446 self.HashBlocks(self.tgt, xf.tgt_ranges), 447 xf.tgt_ranges.to_string_raw(), src_str)) 448 total += tgt_size 449 elif xf.style in ("bsdiff", "imgdiff"): 450 assert xf.tgt_ranges 451 assert xf.src_ranges 452 if self.version == 1: 453 out.append("%s %d %d %s %s\n" % ( 454 xf.style, xf.patch_start, xf.patch_len, 455 xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw())) 456 elif self.version == 2: 457 out.append("%s %d %d %s %s\n" % ( 458 xf.style, xf.patch_start, xf.patch_len, 459 xf.tgt_ranges.to_string_raw(), src_str)) 460 elif self.version >= 3: 461 # take into account automatic stashing of overlapping blocks 462 if xf.src_ranges.overlaps(xf.tgt_ranges): 463 temp_stash_usage = stashed_blocks + xf.src_ranges.size() 464 if temp_stash_usage > max_stashed_blocks: 465 max_stashed_blocks = temp_stash_usage 466 467 out.append("%s %d %d %s %s %s %s\n" % ( 468 xf.style, 469 xf.patch_start, xf.patch_len, 470 self.HashBlocks(self.src, xf.src_ranges), 471 self.HashBlocks(self.tgt, xf.tgt_ranges), 472 xf.tgt_ranges.to_string_raw(), src_str)) 473 total += tgt_size 474 elif xf.style == "zero": 475 assert xf.tgt_ranges 476 to_zero = xf.tgt_ranges.subtract(xf.src_ranges) 477 if to_zero: 478 out.append("%s %s\n" % (xf.style, to_zero.to_string_raw())) 479 
total += to_zero.size() 480 else: 481 raise ValueError("unknown transfer style '%s'\n" % xf.style) 482 483 if free_string: 484 out.append("".join(free_string)) 485 486 if self.version >= 2: 487 # Sanity check: abort if we're going to need more stash space than 488 # the allowed size (cache_size * threshold). There are two purposes 489 # of having a threshold here. a) Part of the cache may have been 490 # occupied by some recovery logs. b) It will buy us some time to deal 491 # with the oversize issue. 492 cache_size = common.OPTIONS.cache_size 493 stash_threshold = common.OPTIONS.stash_threshold 494 max_allowed = cache_size * stash_threshold 495 assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \ 496 'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % ( 497 max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks, 498 self.tgt.blocksize, max_allowed, cache_size, 499 stash_threshold) 500 501 # Zero out extended blocks as a workaround for bug 20881595. 502 if self.tgt.extended: 503 out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),)) 504 505 # We erase all the blocks on the partition that a) don't contain useful 506 # data in the new image and b) will not be touched by dm-verity. 
507 all_tgt = RangeSet(data=(0, self.tgt.total_blocks)) 508 all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended) 509 new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map) 510 if new_dontcare: 511 out.append("erase %s\n" % (new_dontcare.to_string_raw(),)) 512 513 out.insert(0, "%d\n" % (self.version,)) # format version number 514 out.insert(1, str(total) + "\n") 515 if self.version >= 2: 516 # version 2 only: after the total block count, we give the number 517 # of stash slots needed, and the maximum size needed (in blocks) 518 out.insert(2, str(next_stash_id) + "\n") 519 out.insert(3, str(max_stashed_blocks) + "\n") 520 521 with open(prefix + ".transfer.list", "wb") as f: 522 for i in out: 523 f.write(i) 524 525 if self.version >= 2: 526 max_stashed_size = max_stashed_blocks * self.tgt.blocksize 527 max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold 528 print("max stashed blocks: %d (%d bytes), limit: %d bytes (%.2f%%)\n" % ( 529 max_stashed_blocks, max_stashed_size, max_allowed, 530 max_stashed_size * 100.0 / max_allowed)) 531 532 def ReviseStashSize(self): 533 print("Revising stash size...") 534 stashes = {} 535 536 # Create the map between a stash and its def/use points. For example, for a 537 # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd). 538 for xf in self.transfers: 539 # Command xf defines (stores) all the stashes in stash_before. 540 for idx, sr in xf.stash_before: 541 stashes[idx] = (sr, xf) 542 543 # Record all the stashes command xf uses. 544 for idx, _ in xf.use_stash: 545 stashes[idx] += (xf,) 546 547 # Compute the maximum blocks available for stash based on /cache size and 548 # the threshold. 549 cache_size = common.OPTIONS.cache_size 550 stash_threshold = common.OPTIONS.stash_threshold 551 max_allowed = cache_size * stash_threshold / self.tgt.blocksize 552 553 stashed_blocks = 0 554 new_blocks = 0 555 556 # Now go through all the commands. Compute the required stash size on the 557 # fly. 
If a command requires excess stash than available, it deletes the 558 # stash by replacing the command that uses the stash with a "new" command 559 # instead. 560 for xf in self.transfers: 561 replaced_cmds = [] 562 563 # xf.stash_before generates explicit stash commands. 564 for idx, sr in xf.stash_before: 565 if stashed_blocks + sr.size() > max_allowed: 566 # We cannot stash this one for a later command. Find out the command 567 # that will use this stash and replace the command with "new". 568 use_cmd = stashes[idx][2] 569 replaced_cmds.append(use_cmd) 570 print("%10d %9s %s" % (sr.size(), "explicit", use_cmd)) 571 else: 572 stashed_blocks += sr.size() 573 574 # xf.use_stash generates free commands. 575 for _, sr in xf.use_stash: 576 stashed_blocks -= sr.size() 577 578 # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to 579 # ComputePatches(), they both have the style of "diff". 580 if xf.style == "diff" and self.version >= 3: 581 assert xf.tgt_ranges and xf.src_ranges 582 if xf.src_ranges.overlaps(xf.tgt_ranges): 583 if stashed_blocks + xf.src_ranges.size() > max_allowed: 584 replaced_cmds.append(xf) 585 print("%10d %9s %s" % (xf.src_ranges.size(), "implicit", xf)) 586 587 # Replace the commands in replaced_cmds with "new"s. 588 for cmd in replaced_cmds: 589 # It no longer uses any commands in "use_stash". Remove the def points 590 # for all those stashes. 591 for idx, sr in cmd.use_stash: 592 def_cmd = stashes[idx][1] 593 assert (idx, sr) in def_cmd.stash_before 594 def_cmd.stash_before.remove((idx, sr)) 595 new_blocks += sr.size() 596 597 cmd.ConvertToNew() 598 599 print(" Total %d blocks are packed as new blocks due to insufficient " 600 "cache size." 
% (new_blocks,)) 601 602 def ComputePatches(self, prefix): 603 print("Reticulating splines...") 604 diff_q = [] 605 patch_num = 0 606 with open(prefix + ".new.dat", "wb") as new_f: 607 for xf in self.transfers: 608 if xf.style == "zero": 609 pass 610 elif xf.style == "new": 611 for piece in self.tgt.ReadRangeSet(xf.tgt_ranges): 612 new_f.write(piece) 613 elif xf.style == "diff": 614 src = self.src.ReadRangeSet(xf.src_ranges) 615 tgt = self.tgt.ReadRangeSet(xf.tgt_ranges) 616 617 # We can't compare src and tgt directly because they may have 618 # the same content but be broken up into blocks differently, eg: 619 # 620 # ["he", "llo"] vs ["h", "ello"] 621 # 622 # We want those to compare equal, ideally without having to 623 # actually concatenate the strings (these may be tens of 624 # megabytes). 625 626 src_sha1 = sha1() 627 for p in src: 628 src_sha1.update(p) 629 tgt_sha1 = sha1() 630 tgt_size = 0 631 for p in tgt: 632 tgt_sha1.update(p) 633 tgt_size += len(p) 634 635 if src_sha1.digest() == tgt_sha1.digest(): 636 # These are identical; we don't need to generate a patch, 637 # just issue copy commands on the device. 638 xf.style = "move" 639 else: 640 # For files in zip format (eg, APKs, JARs, etc.) we would 641 # like to use imgdiff -z if possible (because it usually 642 # produces significantly smaller patches than bsdiff). 643 # This is permissible if: 644 # 645 # - the source and target files are monotonic (ie, the 646 # data is stored with blocks in increasing order), and 647 # - we haven't removed any blocks from the source set. 648 # 649 # If these conditions are satisfied then appending all the 650 # blocks in the set together in order will produce a valid 651 # zip file (plus possibly extra zeros in the last block), 652 # which is what imgdiff needs to operate. (imgdiff is 653 # fine with extra zeros at the end of the file.) 
654 imgdiff = (xf.intact and 655 xf.tgt_name.split(".")[-1].lower() 656 in ("apk", "jar", "zip")) 657 xf.style = "imgdiff" if imgdiff else "bsdiff" 658 diff_q.append((tgt_size, src, tgt, xf, patch_num)) 659 patch_num += 1 660 661 else: 662 assert False, "unknown style " + xf.style 663 664 if diff_q: 665 if self.threads > 1: 666 print("Computing patches (using %d threads)..." % (self.threads,)) 667 else: 668 print("Computing patches...") 669 diff_q.sort() 670 671 patches = [None] * patch_num 672 673 # TODO: Rewrite with multiprocessing.ThreadPool? 674 lock = threading.Lock() 675 def diff_worker(): 676 while True: 677 with lock: 678 if not diff_q: 679 return 680 tgt_size, src, tgt, xf, patchnum = diff_q.pop() 681 patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff")) 682 size = len(patch) 683 with lock: 684 patches[patchnum] = (patch, xf) 685 print("%10d %10d (%6.2f%%) %7s %s" % ( 686 size, tgt_size, size * 100.0 / tgt_size, xf.style, 687 xf.tgt_name if xf.tgt_name == xf.src_name else ( 688 xf.tgt_name + " (from " + xf.src_name + ")"))) 689 690 threads = [threading.Thread(target=diff_worker) 691 for _ in range(self.threads)] 692 for th in threads: 693 th.start() 694 while threads: 695 threads.pop().join() 696 else: 697 patches = [] 698 699 p = 0 700 with open(prefix + ".patch.dat", "wb") as patch_f: 701 for patch, xf in patches: 702 xf.patch_start = p 703 xf.patch_len = len(patch) 704 patch_f.write(patch) 705 p += len(patch) 706 707 def AssertSequenceGood(self): 708 # Simulate the sequences of transfers we will output, and check that: 709 # - we never read a block after writing it, and 710 # - we write every block we care about exactly once. 711 712 # Start with no blocks having been touched yet. 713 touched = RangeSet() 714 715 # Imagine processing the transfers in order. 716 for xf in self.transfers: 717 # Check that the input blocks for this transfer haven't yet been touched. 
718 719 x = xf.src_ranges 720 if self.version >= 2: 721 for _, sr in xf.use_stash: 722 x = x.subtract(sr) 723 724 assert not touched.overlaps(x) 725 # Check that the output blocks for this transfer haven't yet been touched. 726 assert not touched.overlaps(xf.tgt_ranges) 727 # Touch all the blocks written by this transfer. 728 touched = touched.union(xf.tgt_ranges) 729 730 # Check that we've written every target block. 731 assert touched == self.tgt.care_map 732 733 def ImproveVertexSequence(self): 734 print("Improving vertex order...") 735 736 # At this point our digraph is acyclic; we reversed any edges that 737 # were backwards in the heuristically-generated sequence. The 738 # previously-generated order is still acceptable, but we hope to 739 # find a better order that needs less memory for stashed data. 740 # Now we do a topological sort to generate a new vertex order, 741 # using a greedy algorithm to choose which vertex goes next 742 # whenever we have a choice. 743 744 # Make a copy of the edge set; this copy will get destroyed by the 745 # algorithm. 746 for xf in self.transfers: 747 xf.incoming = xf.goes_after.copy() 748 xf.outgoing = xf.goes_before.copy() 749 750 L = [] # the new vertex order 751 752 # S is the set of sources in the remaining graph; we always choose 753 # the one that leaves the least amount of stashed data after it's 754 # executed. 755 S = [(u.NetStashChange(), u.order, u) for u in self.transfers 756 if not u.incoming] 757 heapq.heapify(S) 758 759 while S: 760 _, _, xf = heapq.heappop(S) 761 L.append(xf) 762 for u in xf.outgoing: 763 del u.incoming[xf] 764 if not u.incoming: 765 heapq.heappush(S, (u.NetStashChange(), u.order, u)) 766 767 # if this fails then our graph had a cycle. 
768 assert len(L) == len(self.transfers) 769 770 self.transfers = L 771 for i, xf in enumerate(L): 772 xf.order = i 773 774 def RemoveBackwardEdges(self): 775 print("Removing backward edges...") 776 in_order = 0 777 out_of_order = 0 778 lost_source = 0 779 780 for xf in self.transfers: 781 lost = 0 782 size = xf.src_ranges.size() 783 for u in xf.goes_before: 784 # xf should go before u 785 if xf.order < u.order: 786 # it does, hurray! 787 in_order += 1 788 else: 789 # it doesn't, boo. trim the blocks that u writes from xf's 790 # source, so that xf can go after u. 791 out_of_order += 1 792 assert xf.src_ranges.overlaps(u.tgt_ranges) 793 xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges) 794 xf.intact = False 795 796 if xf.style == "diff" and not xf.src_ranges: 797 # nothing left to diff from; treat as new data 798 xf.style = "new" 799 800 lost = size - xf.src_ranges.size() 801 lost_source += lost 802 803 print((" %d/%d dependencies (%.2f%%) were violated; " 804 "%d source blocks removed.") % 805 (out_of_order, in_order + out_of_order, 806 (out_of_order * 100.0 / (in_order + out_of_order)) 807 if (in_order + out_of_order) else 0.0, 808 lost_source)) 809 810 def ReverseBackwardEdges(self): 811 print("Reversing backward edges...") 812 in_order = 0 813 out_of_order = 0 814 stashes = 0 815 stash_size = 0 816 817 for xf in self.transfers: 818 for u in xf.goes_before.copy(): 819 # xf should go before u 820 if xf.order < u.order: 821 # it does, hurray! 822 in_order += 1 823 else: 824 # it doesn't, boo. modify u to stash the blocks that it 825 # writes that xf wants to read, and then require u to go 826 # before xf. 
827 out_of_order += 1 828 829 overlap = xf.src_ranges.intersect(u.tgt_ranges) 830 assert overlap 831 832 u.stash_before.append((stashes, overlap)) 833 xf.use_stash.append((stashes, overlap)) 834 stashes += 1 835 stash_size += overlap.size() 836 837 # reverse the edge direction; now xf must go after u 838 del xf.goes_before[u] 839 del u.goes_after[xf] 840 xf.goes_after[u] = None # value doesn't matter 841 u.goes_before[xf] = None 842 843 print((" %d/%d dependencies (%.2f%%) were violated; " 844 "%d source blocks stashed.") % 845 (out_of_order, in_order + out_of_order, 846 (out_of_order * 100.0 / (in_order + out_of_order)) 847 if (in_order + out_of_order) else 0.0, 848 stash_size)) 849 850 def FindVertexSequence(self): 851 print("Finding vertex sequence...") 852 853 # This is based on "A Fast & Effective Heuristic for the Feedback 854 # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth. Think of 855 # it as starting with the digraph G and moving all the vertices to 856 # be on a horizontal line in some order, trying to minimize the 857 # number of edges that end up pointing to the left. Left-pointing 858 # edges will get removed to turn the digraph into a DAG. In this 859 # case each edge has a weight which is the number of source blocks 860 # we'll lose if that edge is removed; we try to minimize the total 861 # weight rather than just the number of edges. 862 863 # Make a copy of the edge set; this copy will get destroyed by the 864 # algorithm. 865 for xf in self.transfers: 866 xf.incoming = xf.goes_after.copy() 867 xf.outgoing = xf.goes_before.copy() 868 869 # We use an OrderedDict instead of just a set so that the output 870 # is repeatable; otherwise it would depend on the hash values of 871 # the transfer objects. 
872 G = OrderedDict() 873 for xf in self.transfers: 874 G[xf] = None 875 s1 = deque() # the left side of the sequence, built from left to right 876 s2 = deque() # the right side of the sequence, built from right to left 877 878 while G: 879 880 # Put all sinks at the end of the sequence. 881 while True: 882 sinks = [u for u in G if not u.outgoing] 883 if not sinks: 884 break 885 for u in sinks: 886 s2.appendleft(u) 887 del G[u] 888 for iu in u.incoming: 889 del iu.outgoing[u] 890 891 # Put all the sources at the beginning of the sequence. 892 while True: 893 sources = [u for u in G if not u.incoming] 894 if not sources: 895 break 896 for u in sources: 897 s1.append(u) 898 del G[u] 899 for iu in u.outgoing: 900 del iu.incoming[u] 901 902 if not G: 903 break 904 905 # Find the "best" vertex to put next. "Best" is the one that 906 # maximizes the net difference in source blocks saved we get by 907 # pretending it's a source rather than a sink. 908 909 max_d = None 910 best_u = None 911 for u in G: 912 d = sum(u.outgoing.values()) - sum(u.incoming.values()) 913 if best_u is None or d > max_d: 914 max_d = d 915 best_u = u 916 917 u = best_u 918 s1.append(u) 919 del G[u] 920 for iu in u.outgoing: 921 del iu.incoming[u] 922 for iu in u.incoming: 923 del iu.outgoing[u] 924 925 # Now record the sequence in the 'order' field of each transfer, 926 # and by rearranging self.transfers to be in the chosen sequence. 927 928 new_transfers = [] 929 for x in itertools.chain(s1, s2): 930 x.order = len(new_transfers) 931 new_transfers.append(x) 932 del x.incoming 933 del x.outgoing 934 935 self.transfers = new_transfers 936 937 def GenerateDigraph(self): 938 print("Generating digraph...") 939 for a in self.transfers: 940 for b in self.transfers: 941 if a is b: 942 continue 943 944 # If the blocks written by A are read by B, then B needs to go before A. 
945 i = a.tgt_ranges.intersect(b.src_ranges) 946 if i: 947 if b.src_name == "__ZERO": 948 # the cost of removing source blocks for the __ZERO domain 949 # is (nearly) zero. 950 size = 0 951 else: 952 size = i.size() 953 b.goes_before[a] = size 954 a.goes_after[b] = size 955 956 def FindTransfers(self): 957 """Parse the file_map to generate all the transfers.""" 958 959 def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id, 960 split=False): 961 """Wrapper function for adding a Transfer(). 962 963 For BBOTA v3, we need to stash source blocks for resumable feature. 964 However, with the growth of file size and the shrink of the cache 965 partition source blocks are too large to be stashed. If a file occupies 966 too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it 967 into smaller pieces by getting multiple Transfer()s. 968 969 The downside is that after splitting, we can no longer use imgdiff but 970 only bsdiff.""" 971 972 MAX_BLOCKS_PER_DIFF_TRANSFER = 1024 973 974 # We care about diff transfers only. 975 if style != "diff" or not split: 976 Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id) 977 return 978 979 # Change nothing for small files. 
980 if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and 981 src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER): 982 Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id) 983 return 984 985 pieces = 0 986 while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and 987 src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER): 988 tgt_split_name = "%s-%d" % (tgt_name, pieces) 989 src_split_name = "%s-%d" % (src_name, pieces) 990 tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER) 991 src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER) 992 Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style, 993 by_id) 994 995 tgt_ranges = tgt_ranges.subtract(tgt_first) 996 src_ranges = src_ranges.subtract(src_first) 997 pieces += 1 998 999 # Handle remaining blocks. 1000 if tgt_ranges.size() or src_ranges.size(): 1001 # Must be both non-empty. 1002 assert tgt_ranges.size() and src_ranges.size() 1003 tgt_split_name = "%s-%d" % (tgt_name, pieces) 1004 src_split_name = "%s-%d" % (src_name, pieces) 1005 Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style, 1006 by_id) 1007 1008 empty = RangeSet() 1009 for tgt_fn, tgt_ranges in self.tgt.file_map.items(): 1010 if tgt_fn == "__ZERO": 1011 # the special "__ZERO" domain is all the blocks not contained 1012 # in any file and that are filled with zeros. We have a 1013 # special transfer style for zero blocks. 1014 src_ranges = self.src.file_map.get("__ZERO", empty) 1015 AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges, 1016 "zero", self.transfers) 1017 continue 1018 1019 elif tgt_fn == "__COPY": 1020 # "__COPY" domain includes all the blocks not contained in any 1021 # file and that need to be copied unconditionally to the target. 1022 AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers) 1023 continue 1024 1025 elif tgt_fn in self.src.file_map: 1026 # Look for an exact pathname match in the source. 
1027 AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn], 1028 "diff", self.transfers, self.version >= 3) 1029 continue 1030 1031 b = os.path.basename(tgt_fn) 1032 if b in self.src_basenames: 1033 # Look for an exact basename match in the source. 1034 src_fn = self.src_basenames[b] 1035 AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 1036 "diff", self.transfers, self.version >= 3) 1037 continue 1038 1039 b = re.sub("[0-9]+", "#", b) 1040 if b in self.src_numpatterns: 1041 # Look for a 'number pattern' match (a basename match after 1042 # all runs of digits are replaced by "#"). (This is useful 1043 # for .so files that contain version numbers in the filename 1044 # that get bumped.) 1045 src_fn = self.src_numpatterns[b] 1046 AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 1047 "diff", self.transfers, self.version >= 3) 1048 continue 1049 1050 AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers) 1051 1052 def AbbreviateSourceNames(self): 1053 for k in self.src.file_map.keys(): 1054 b = os.path.basename(k) 1055 self.src_basenames[b] = k 1056 b = re.sub("[0-9]+", "#", b) 1057 self.src_numpatterns[b] = k 1058 1059 @staticmethod 1060 def AssertPartition(total, seq): 1061 """Assert that all the RangeSets in 'seq' form a partition of the 1062 'total' RangeSet (ie, they are nonintersecting and their union 1063 equals 'total').""" 1064 so_far = RangeSet() 1065 for i in seq: 1066 assert not so_far.overlaps(i) 1067 so_far = so_far.union(i) 1068 assert so_far == total 1069