blockimgdiff.py revision 21b37d83ebbebb461baa6bdc683f8920e7666ae5
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  """Run bsdiff (or imgdiff -z) on two byte sequences and return the patch.

  Args:
    src: iterable of byte strings; their concatenation is the source file.
    tgt: iterable of byte strings; their concatenation is the target file.
    imgdiff: if True run "imgdiff -z" (output discarded), otherwise "bsdiff".

  Returns:
    The contents of the generated patch file.

  Raises:
    ValueError: if the diff tool exits with a non-zero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # mkstemp() was only used to reserve a unique name; remove the empty file
    # so that the diff tool creates the patch itself.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Close the null device once the tool has finished instead of leaking
      # one file object per call.
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup. Each file is removed independently so that one
    # missing file does not leave the others behind.
    for name in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(name)
      except OSError:
        pass


class Image(object):
  """Base class spelling out the two methods every image must provide."""

  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' and classify its 4096-byte blocks.

    Args:
      data: the image contents as one string.
      trim: drop a trailing partial block, if any.
      pad: extend a trailing partial block with NUL bytes.

    Raises:
      ValueError: if the length of data is not a multiple of the block size
        and neither trim nor pad is given.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps the block count an integer even when "/" means
    # true division.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Expose a RangeSet, as the image interface documents, so that it can be
    # handed to RangeSet.subtract() (see TotalSha1).
    if clobbered_blocks:
      self.clobbered_blocks = RangeSet(data=clobbered_blocks)
    else:
      self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # A padded last block is left out here; it goes into "__COPY" below.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Return one slice of the data per (start, end) block pair in 'ranges'."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the image data.

    The clobbered blocks are left out of the hash unless
    include_clobbered_blocks is True.
    """
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # ReadRangeSet() returns a list of pieces; sha1() cannot take a list,
      # so feed the pieces in one by one.
      ctx = sha1()
      for piece in self.ReadRangeSet(ranges):
        ctx.update(piece)
      return ctx.hexdigest()
    else:
      return sha1(self.data).hexdigest()


class Transfer(object):
  """One command of the transfer list: builds tgt_ranges from src_ranges.

  Creating a Transfer appends it to 'by_id' and uses its position there as
  the transfer id.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # True only while both range sets report themselves as monotonic.
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    # Lists of (stash id, RangeSet) pairs.
    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Return blocks stashed by this transfer minus blocks it consumes."""
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    """Turn this transfer into a "new" one that reads nothing from source."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks)) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes a block-based diff between a source and a target image.

  Compute(prefix) writes prefix.new.dat, prefix.patch.dat and
  prefix.transfer.list.
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    """Set up a diff from 'src' (or an EmptyImage if None) to 'tgt'.

    Args:
      tgt: the target image object.
      src: the source image object, or None.
      threads: number of patch-computing threads; defaults to half the CPU
        count, but at least one.
      version: transfer list format version; must be 1, 2 or 3.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write the output files named by 'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges):  # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data 'source' holds in 'ranges'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Write the command list for self.transfers to prefix.transfer.list."""
    out = []

    total = 0

    # Maps a stash index to its slot id. For version 3 it additionally maps
    # the hash of the stashed data to a reference count.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          # The hash covers the blocks at their original location, so it has
          # to be taken before sr is remapped. Version 2 refers to stashes by
          # slot id and never needs it.
          sh = self.HashBlocks(self.src, sr) if self.version != 2 else None
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
                  max_stashed_blocks, max_stashed_size, max_allowed,
                  max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d (%d bytes), limit: <unknown>\n" % (
            max_stashed_blocks, max_stashed_size))

  def ReviseStashSize(self):
    """Convert commands to "new" where their stash would exceed the limit.

    The limit is common.OPTIONS.cache_size * common.OPTIONS.stash_threshold,
    expressed in blocks.
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d %9s %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d %9s %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Count the target blocks that now ship as new data. Summing the
        # removed stash sizes instead would report zero for commands replaced
        # because of an implicit stash, since those have no use_stash entries.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    print(" Total %d blocks are packed as new blocks due to insufficient "
          "cache size." % (new_blocks,))

  def ComputePatches(self, prefix):
    """Write prefix.new.dat and prefix.patch.dat.

    Every "diff" transfer ends up as "move", "imgdiff" or "bsdiff"; the
    patched ones get their patch_start and patch_len attributes set.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Workers pop from the end, so the largest targets are diffed first.
      # Sort on the size alone: comparing whole tuples would fall through to
      # the data lists and the Transfer objects whenever two sizes tie.
      diff_q.sort(key=lambda item: item[0])

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert that self.transfers, executed in order, is consistent."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Reorder self.transfers by a greedy topological sort."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Drop source blocks that an earlier transfer has already overwritten."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost_source += size - xf.src_ranges.size()

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Satisfy violated dependencies by stashing the overlapping blocks."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      # Iterate over a copy: the loop body edits xf.goes_before.
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Order self.transfers heuristically and set each transfer's 'order'."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Fill in goes_before/goes_after for every pair of transfers."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
      However, with the growth of file size and the shrink of the cache
      partition source blocks are too large to be stashed. If a file occupies
      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
      into smaller pieces by getting multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff but
      only bsdiff."""

      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index the source file names by basename and by number pattern."""
    for k in self.src.file_map:
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total