# blockimgdiff.py — revision 66f1fa65658d33f9bdc42691738a5fb6559560cb
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import array
import common
import functools
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import time
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms src into tgt.

  src and tgt are iterables of strings; their concatenated contents are
  written to temp files and handed to the external "imgdiff -z" tool
  (zip-aware diffing) when imgdiff is True, otherwise to "bsdiff".

  Returns the patch contents as a string. Raises ValueError if the
  external tool exits with a non-zero status. Temp files are always
  removed on exit.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    # The diff tools create the patch file themselves; remove the
    # placeholder created by mkstemp first.
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  """Abstract interface for the image objects BlockImageDiff operates on."""

  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) / self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # BUGFIX: store clobbered_blocks as a RangeSet, as the Image interface
    # documents. TotalSha1() below feeds it to RangeSet.subtract(), which
    # would fail on a plain list.
    self.clobbered_blocks = (
        RangeSet(data=clobbered_blocks) if clobbered_blocks else RangeSet())
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify each block (excluding a padded tail block, which goes into
    # clobbered_blocks) as all-zero or not. The (i, i+1) pairs form the
    # half-open interval list that RangeSet(data=...) expects.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # BUGFIX: ReadRangeSet() returns a *list* of strings; sha1() only
      # accepts a string (raises TypeError on a list), so feed the pieces
      # to the hash in order instead.
      ctx = sha1()
      for data in self.ReadRangeSet(ranges):
        ctx.update(data)
      return ctx.hexdigest()
    else:
      return sha1(self.data).hexdigest()


class Transfer(object):
  """One update command: produce tgt_ranges from src_ranges via `style`."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Net change in stashed blocks if this transfer is executed."""
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    """Rewrite this transfer as a "new" command that reads no source data."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


@functools.total_ordering
class HeapItem(object):
  """Mutable wrapper so entries can be invalidated while on the heap."""
  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score
  def clear(self):
    self.item = None
  def __bool__(self):
    # BUGFIX: a HeapItem is truthy while it still wraps a live item. The
    # test was inverted (`is None`), which made cleared/stale heap entries
    # look live to consumers that test truthiness.
    return self.item is not None
  # BUGFIX: Python 2 uses __nonzero__ for truth testing, while Python 3
  # uses __bool__; alias it so the check works on the py2 runtime this
  # file targets.
  __nonzero__ = __bool__
  def __eq__(self, other):
    return self.score == other.score
  def __le__(self, other):
    return self.score <= other.score


# BlockImageDiff works on two image objects. An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet. The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes the sequence of transfers (and patches) that turns the src
  block image into the tgt block image.

  Compute() drives the whole pipeline and writes the output files under
  the given prefix (<prefix>.transfer.list, <prefix>.new.dat,
  <prefix>.patch.dat).
  """

  def __init__(self, tgt, src=None, threads=None, version=4):
    # Default to half the CPUs (minimum 1) for patch computation.
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}
    self._max_stashed_size = 0
    # Source blocks read (or stashed) by the generated script; used for
    # erase computation and the v3+ source hash.
    self.touched_src_ranges = RangeSet()
    self.touched_src_sha1 = None

    assert version in (1, 2, 3, 4)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  @property
  def max_stashed_size(self):
    # Maximum number of bytes simultaneously stashed, in the last
    # WriteTransfers() run (0 before then).
    return self._max_stashed_size

  def Compute(self, prefix):
    """Run the full diff pipeline and write the output files under prefix."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data in `ranges` of image `source`."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit <prefix>.transfer.list, tracking stash usage as it goes."""
    out = []

    # Total number of target blocks written by the script.
    total = 0

    # For v2: maps stash raw id -> slot id. For v3+: maps stash hash ->
    # outstanding reference count. (Both keyspaces share this dict.)
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # v2 slot ids are recycled through a min-heap of freed ids.
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # v3+ stashes are content-addressed; identical contents share
          # one stash and are reference-counted.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            self.touched_src_ranges = self.touched_src_ranges.union(sr)
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stashed range in coordinates relative to this
          # command's source block list.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            # Free the stash only when its last reference is consumed.
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # Skip the no-op case where source and target ranges coincide.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            self.touched_src_ranges = self.touched_src_ranges.union(
                xf.src_ranges)

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          self.touched_src_ranges = self.touched_src_ranges.union(
              xf.src_ranges)

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
            'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                self.tgt.blocksize, max_allowed, cache_size,
                stash_threshold)

    if self.version >= 3:
      self.touched_src_sha1 = self.HashBlocks(
          self.src, self.touched_src_ranges)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image; b) will not be touched by dm-verity. Out of those
    # blocks, we erase the ones that won't be used in this update at the
    # beginning of an update. The rest would be erased at the end. This is to
    # work around the eMMC issue observed on some devices, which may otherwise
    # get starving for clean blocks and thus fail the update. (b/28347095)
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)

    erase_first = new_dontcare.subtract(self.touched_src_ranges)
    if erase_first:
      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))

    erase_last = new_dontcare.subtract(erase_first)
    if erase_last:
      out.append("erase %s\n" % (erase_last.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
                  max_stashed_blocks, self._max_stashed_size, max_allowed,
                  self._max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d (%d bytes), limit: <unknown>\n" % (
            max_stashed_blocks, self._max_stashed_size))

  def ReviseStashSize(self):
    """Convert transfers to "new" where needed to stay under the stash cap."""
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d %9s %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d %9s %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up blocks that violates space limit and print total number to
        # screen later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print(" Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))

  def ComputePatches(self, prefix):
    """Write <prefix>.new.dat and <prefix>.patch.dat; resolve "diff" styles
    to "move"/"bsdiff"/"imgdiff" and assign patch offsets/lengths."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Workers pop work items under the lock; patch computation itself
        # runs unlocked (the heavy part is the external diff tool).
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Sanity-check the final transfer order."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = array.array("B", "\0" * self.tgt.total_blocks)

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Stashed blocks were captured earlier, so reading them later is
        # fine; exclude them from the check.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      for s, e in x:
        # Source image could be larger. Don't check the blocks that are in the
        # source image only. Since they are not in 'touched', and won't ever
        # be touched.
        for i in range(s, min(e, self.tgt.total_blocks)):
          assert touched[i] == 0

      # Check that the output blocks for this transfer haven't yet
      # been touched, and touch all the blocks written by this
      # transfer.
      for s, e in xf.tgt_ranges:
        for i in range(s, e):
          assert touched[i] == 0
          touched[i] = 1

    # Check that we've written every target block.
    for s, e in self.tgt.care_map:
      for i in range(s, e):
        assert touched[i] == 1

  def ImproveVertexSequence(self):
    """Topologically re-sort the (now acyclic) transfers to reduce the
    amount of simultaneously stashed data."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """(v1 only) Drop ordering violations by trimming the source blocks
    that a later-ordered transfer would overwrite."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """(v2+) Resolve ordering violations by stashing the contested blocks
    and reversing the violated edges."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Heuristically order transfers to minimize the weight of violated
    ordering dependencies (feedback arc set heuristic)."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    sinks = set(u for u in G if not u.outgoing)
    sources = set(u for u in G if not u.incoming)

    def adjust_score(iu, delta):
      # Scores are immutable inside the heap: invalidate the old wrapper
      # and push a fresh one with the updated score.
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = set()
        for u in sinks:
          if u not in G: continue
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing: new_sinks.add(iu)
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = set()
        for u in sources:
          if u not in G: continue
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming: new_sources.add(iu)
        sources = new_sources

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      # Pop until we find an entry that is still live and still in G.
      while True:
        u = heapq.heappop(heap)
        if u and u.item in G:
          u = u.item
          break

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming: sources.add(iu)

      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing: sinks.add(iu)

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between transfers whose target
    blocks overlap another transfer's source blocks."""
    print("Generating digraph...")

    # Each item of source_ranges will be:
    #   - None, if that block is not used as a source,
    #   - a transfer, if one transfer uses it as a source, or
    #   - a set of transfers.
    source_ranges = []
    for b in self.transfers:
      for s, e in b.src_ranges:
        if e > len(source_ranges):
          source_ranges.extend([None] * (e-len(source_ranges)))
        for i in range(s, e):
          if source_ranges[i] is None:
            source_ranges[i] = b
          else:
            if not isinstance(source_ranges[i], set):
              source_ranges[i] = set([source_ranges[i]])
            source_ranges[i].add(b)

    for a in self.transfers:
      intersections = set()
      for s, e in a.tgt_ranges:
        for i in range(s, e):
          if i >= len(source_ranges): break
          b = source_ranges[i]
          if b is not None:
            if isinstance(b, set):
              intersections.update(b)
            else:
              intersections.add(b)

      for b in intersections:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
1091 However, with the growth of file size and the shrink of the cache 1092 partition source blocks are too large to be stashed. If a file occupies 1093 too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it 1094 into smaller pieces by getting multiple Transfer()s. 1095 1096 The downside is that after splitting, we may increase the package size 1097 since the split pieces don't align well. According to our experiments, 1098 1/8 of the cache size as the per-piece limit appears to be optimal. 1099 Compared to the fixed 1024-block limit, it reduces the overall package 1100 size by 30% volantis, and 20% for angler and bullhead.""" 1101 1102 # We care about diff transfers only. 1103 if style != "diff" or not split: 1104 Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id) 1105 return 1106 1107 pieces = 0 1108 cache_size = common.OPTIONS.cache_size 1109 split_threshold = 0.125 1110 max_blocks_per_transfer = int(cache_size * split_threshold / 1111 self.tgt.blocksize) 1112 1113 # Change nothing for small files. 1114 if (tgt_ranges.size() <= max_blocks_per_transfer and 1115 src_ranges.size() <= max_blocks_per_transfer): 1116 Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id) 1117 return 1118 1119 while (tgt_ranges.size() > max_blocks_per_transfer and 1120 src_ranges.size() > max_blocks_per_transfer): 1121 tgt_split_name = "%s-%d" % (tgt_name, pieces) 1122 src_split_name = "%s-%d" % (src_name, pieces) 1123 tgt_first = tgt_ranges.first(max_blocks_per_transfer) 1124 src_first = src_ranges.first(max_blocks_per_transfer) 1125 1126 Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style, 1127 by_id) 1128 1129 tgt_ranges = tgt_ranges.subtract(tgt_first) 1130 src_ranges = src_ranges.subtract(src_first) 1131 pieces += 1 1132 1133 # Handle remaining blocks. 1134 if tgt_ranges.size() or src_ranges.size(): 1135 # Must be both non-empty. 
1136 assert tgt_ranges.size() and src_ranges.size() 1137 tgt_split_name = "%s-%d" % (tgt_name, pieces) 1138 src_split_name = "%s-%d" % (src_name, pieces) 1139 Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style, 1140 by_id) 1141 1142 empty = RangeSet() 1143 for tgt_fn, tgt_ranges in self.tgt.file_map.items(): 1144 if tgt_fn == "__ZERO": 1145 # the special "__ZERO" domain is all the blocks not contained 1146 # in any file and that are filled with zeros. We have a 1147 # special transfer style for zero blocks. 1148 src_ranges = self.src.file_map.get("__ZERO", empty) 1149 AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges, 1150 "zero", self.transfers) 1151 continue 1152 1153 elif tgt_fn == "__COPY": 1154 # "__COPY" domain includes all the blocks not contained in any 1155 # file and that need to be copied unconditionally to the target. 1156 AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers) 1157 continue 1158 1159 elif tgt_fn in self.src.file_map: 1160 # Look for an exact pathname match in the source. 1161 AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn], 1162 "diff", self.transfers, self.version >= 3) 1163 continue 1164 1165 b = os.path.basename(tgt_fn) 1166 if b in self.src_basenames: 1167 # Look for an exact basename match in the source. 1168 src_fn = self.src_basenames[b] 1169 AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 1170 "diff", self.transfers, self.version >= 3) 1171 continue 1172 1173 b = re.sub("[0-9]+", "#", b) 1174 if b in self.src_numpatterns: 1175 # Look for a 'number pattern' match (a basename match after 1176 # all runs of digits are replaced by "#"). (This is useful 1177 # for .so files that contain version numbers in the filename 1178 # that get bumped.) 
1179 src_fn = self.src_numpatterns[b] 1180 AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 1181 "diff", self.transfers, self.version >= 3) 1182 continue 1183 1184 AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers) 1185 1186 def AbbreviateSourceNames(self): 1187 for k in self.src.file_map.keys(): 1188 b = os.path.basename(k) 1189 self.src_basenames[b] = k 1190 b = re.sub("[0-9]+", "#", b) 1191 self.src_numpatterns[b] = k 1192 1193 @staticmethod 1194 def AssertPartition(total, seq): 1195 """Assert that all the RangeSets in 'seq' form a partition of the 1196 'total' RangeSet (ie, they are nonintersecting and their union 1197 equals 'total').""" 1198 1199 so_far = RangeSet() 1200 for i in seq: 1201 assert not so_far.overlaps(i) 1202 so_far = so_far.union(i) 1203 assert so_far == total 1204