blockimgdiff.py revision 937847ae493d36b9741f6387a2357d5cdceda3d9
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import common 20import heapq 21import itertools 22import multiprocessing 23import os 24import re 25import subprocess 26import threading 27import tempfile 28 29from rangelib import RangeSet 30 31 32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 33 34 35def compute_patch(src, tgt, imgdiff=False): 36 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 37 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 38 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 39 os.close(patchfd) 40 41 try: 42 with os.fdopen(srcfd, "wb") as f_src: 43 for p in src: 44 f_src.write(p) 45 46 with os.fdopen(tgtfd, "wb") as f_tgt: 47 for p in tgt: 48 f_tgt.write(p) 49 try: 50 os.unlink(patchfile) 51 except OSError: 52 pass 53 if imgdiff: 54 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 55 stdout=open("/dev/null", "a"), 56 stderr=subprocess.STDOUT) 57 else: 58 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 59 60 if p: 61 raise ValueError("diff failed: " + str(p)) 62 63 with open(patchfile, "rb") as f: 64 return f.read() 65 finally: 66 try: 67 os.unlink(srcfile) 68 os.unlink(tgtfile) 69 os.unlink(patchfile) 70 except OSError: 71 pass 72 73 74class Image(object): 75 def ReadRangeSet(self, ranges): 76 raise NotImplementedError 77 78 
def TotalSha1(self, include_clobbered_blocks=False): 79 raise NotImplementedError 80 81 82class EmptyImage(Image): 83 """A zero-length image.""" 84 blocksize = 4096 85 care_map = RangeSet() 86 clobbered_blocks = RangeSet() 87 extended = RangeSet() 88 total_blocks = 0 89 file_map = {} 90 def ReadRangeSet(self, ranges): 91 return () 92 def TotalSha1(self, include_clobbered_blocks=False): 93 # EmptyImage always carries empty clobbered_blocks, so 94 # include_clobbered_blocks can be ignored. 95 assert self.clobbered_blocks.size() == 0 96 return sha1().hexdigest() 97 98 99class DataImage(Image): 100 """An image wrapped around a single string of data.""" 101 102 def __init__(self, data, trim=False, pad=False): 103 self.data = data 104 self.blocksize = 4096 105 106 assert not (trim and pad) 107 108 partial = len(self.data) % self.blocksize 109 if partial > 0: 110 if trim: 111 self.data = self.data[:-partial] 112 elif pad: 113 self.data += '\0' * (self.blocksize - partial) 114 else: 115 raise ValueError(("data for DataImage must be multiple of %d bytes " 116 "unless trim or pad is specified") % 117 (self.blocksize,)) 118 119 assert len(self.data) % self.blocksize == 0 120 121 self.total_blocks = len(self.data) / self.blocksize 122 self.care_map = RangeSet(data=(0, self.total_blocks)) 123 self.clobbered_blocks = RangeSet() 124 self.extended = RangeSet() 125 126 zero_blocks = [] 127 nonzero_blocks = [] 128 reference = '\0' * self.blocksize 129 130 for i in range(self.total_blocks): 131 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 132 if d == reference: 133 zero_blocks.append(i) 134 zero_blocks.append(i+1) 135 else: 136 nonzero_blocks.append(i) 137 nonzero_blocks.append(i+1) 138 139 self.file_map = {"__ZERO": RangeSet(zero_blocks), 140 "__NONZERO": RangeSet(nonzero_blocks)} 141 142 def ReadRangeSet(self, ranges): 143 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 144 145 def TotalSha1(self, include_clobbered_blocks=False): 146 # DataImage 
always carries empty clobbered_blocks, so 147 # include_clobbered_blocks can be ignored. 148 assert self.clobbered_blocks.size() == 0 149 return sha1(self.data).hexdigest() 150 151 152class Transfer(object): 153 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 154 self.tgt_name = tgt_name 155 self.src_name = src_name 156 self.tgt_ranges = tgt_ranges 157 self.src_ranges = src_ranges 158 self.style = style 159 self.intact = (getattr(tgt_ranges, "monotonic", False) and 160 getattr(src_ranges, "monotonic", False)) 161 162 # We use OrderedDict rather than dict so that the output is repeatable; 163 # otherwise it would depend on the hash values of the Transfer objects. 164 self.goes_before = OrderedDict() 165 self.goes_after = OrderedDict() 166 167 self.stash_before = [] 168 self.use_stash = [] 169 170 self.id = len(by_id) 171 by_id.append(self) 172 173 def NetStashChange(self): 174 return (sum(sr.size() for (_, sr) in self.stash_before) - 175 sum(sr.size() for (_, sr) in self.use_stash)) 176 177 def ConvertToNew(self): 178 assert self.style != "new" 179 self.use_stash = [] 180 self.style = "new" 181 self.src_ranges = RangeSet() 182 183 def __str__(self): 184 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 185 " to " + str(self.tgt_ranges) + ">") 186 187 188# BlockImageDiff works on two image objects. An image object is 189# anything that provides the following attributes: 190# 191# blocksize: the size in bytes of a block, currently must be 4096. 192# 193# total_blocks: the total size of the partition/image, in blocks. 194# 195# care_map: a RangeSet containing which blocks (in the range [0, 196# total_blocks) we actually care about; i.e. which blocks contain 197# data. 198# 199# file_map: a dict that partitions the blocks contained in care_map 200# into smaller domains that are useful for doing diffs on. 201# (Typically a domain is a file, and the key in file_map is the 202# pathname.) 
#
# clobbered_blocks: a RangeSet containing which blocks contain data
#   but may be altered by the FS. They need to be excluded when
#   verifying the partition integrity.
#
# ReadRangeSet(): a function that takes a RangeSet and returns the
#   data contained in the image blocks of that RangeSet. The data
#   is returned as a list or tuple of strings; concatenating the
#   elements together should produce the requested data.
#   Implementations are free to break up the data into list/tuple
#   elements in any way that is convenient.
#
# TotalSha1(): a function that returns (as a hex string) the SHA-1
#   hash of all the data in the image (ie, all the blocks in the
#   care_map minus clobbered_blocks, or including the clobbered
#   blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes the diff of two block images and emits a transfer list.

  Compute(prefix) is the entry point; it writes <prefix>.transfer.list,
  <prefix>.new.dat and <prefix>.patch.dat.  'version' selects the
  on-device transfer-list format (1, 2 or 3).
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      # Default to half the CPUs (diffing is memory-hungry), minimum 1.
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write the output files under 'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data in 'ranges' of image 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Serialize self.transfers into <prefix>.transfer.list.

    Also tracks stash usage as the commands would execute on-device and
    asserts it stays under the cache limit (v2+).
    """
    out = []

    # Total number of target blocks actually written by the commands.
    total = 0

    # v2: maps stash slot id -> slot id.  v3: maps stash hash -> refcount
    # (stashes are addressed by content hash, freed when refcount hits 0).
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the explicit stash commands this transfer requires.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # v3 addresses stashes by hash; only emit the command on the
          # first reference to a given content hash.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stashed range in coordinates relative to this
          # transfer's source blob.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              # Last reference: free the stash after this command.
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op on the device;
        # emit nothing for it.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that were already zero in the source don't need zeroing.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

    if self.version >= 2:
      # Sanity check: abort if we're going to need more stash space than
      # the allowed size (cache_size * threshold). There are two purposes
      # of having a threshold here. a) Part of the cache may have been
      # occupied by some recovery logs. b) It will buy us some time to deal
      # with the oversize issue.
      cache_size = common.OPTIONS.cache_size
      stash_threshold = common.OPTIONS.stash_threshold
      max_allowed = cache_size * stash_threshold
      assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
          'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
              max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
              self.tgt.blocksize, max_allowed, cache_size,
              stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,)) # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold
      print("max stashed blocks: %d (%d bytes), limit: %d bytes (%.2f%%)\n" % (
          max_stashed_blocks, max_stashed_size, max_allowed,
          max_stashed_size * 100.0 / max_allowed))

  def ReviseStashSize(self):
    """Drop stashes that would exceed the cache limit.

    Commands whose stashes can't fit are converted to "new" commands
    (full data in the package) instead.
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d %9s %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d %9s %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))
          new_blocks += sr.size()

        cmd.ConvertToNew()

    print(" Total %d blocks are packed as new blocks due to insufficient "
          "cache size." % (new_blocks,))

  def ComputePatches(self, prefix):
    """Compute binary patches for all "diff" transfers.

    Writes <prefix>.new.dat (data for "new" transfers) and
    <prefix>.patch.dat (concatenated patches); records patch_start /
    patch_len on each patched transfer.  Identical src/tgt pairs become
    "move" transfers instead of patches.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Worker threads pop work items under the lock; the actual
        # diffing (external tool invocation) runs unlocked.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches, recording each one's offset and length.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from the stash were captured before being
        # overwritten, so they don't count as reads of current blocks.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """v1 fixup: drop source blocks that would be read after being
    written, so the sequence needs no stashing at all."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """v2+ fixup: stash blocks that would be overwritten before being
    read, then reverse the violated edge so the graph becomes acyclic."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs
    whose source and target ranges overlap (O(n^2))."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
      However, with the growth of file size and the shrink of the cache
      partition source blocks are too large to be stashed. If a file occupies
      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
      into smaller pieces by getting multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff but
      only bsdiff."""

      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and by digit-squashed basename,
    for the fallback matching used in FindTransfers()."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total