blockimgdiff.py revision 4fcb77e4e35db7fccb7a35ad650a85dc29bd2e73
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  """Computes a binary patch that turns the src data into the tgt data.

  Args:
    src: An iterable of strings; concatenated, they form the source data.
    tgt: An iterable of strings; concatenated, they form the target data.
    imgdiff: If True run "imgdiff -z"; otherwise run "bsdiff".

  Returns:
    The contents of the patch file written by the diff tool.

  Raises:
    ValueError: if the diff tool exits with a non-zero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    # Only the unique name was wanted from mkstemp(); remove the placeholder
    # so that the diff tool writes the output file itself.
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      # imgdiff's output is discarded. Open the null device in a 'with' block
      # so its handle is closed right away rather than leaked.
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Remove each temporary file independently: a failure on one of them
    # (e.g. patchfile was never created because the diff failed) must not
    # leave the others behind.
    for path in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(path)
      except OSError:
        pass


class Image(object):
  """Base class for image objects consumed by BlockImageDiff."""

  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}

  def ReadRangeSet(self, ranges):
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wraps 'data' as a block image with 4096-byte blocks.

    Args:
      data: The image contents.
      trim: If True, drop a trailing partial block.
      pad: If True, zero-pad a trailing partial block to a full block. The
          padded block is then treated as clobbered (see below).

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
          neither trim nor pad is given.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int on Python 3 as well.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Keep clobbered_blocks a RangeSet, as the image interface requires, so
    # that it can be passed to RangeSet.subtract() (see TotalSha1()).
    self.clobbered_blocks = RangeSet(data=clobbered_blocks)
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify every fully-owned block as all-zero or not; the padded last
    # block (if any) is handled by the "__COPY" domain instead.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Returns a list with one data string per (start, end) range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Returns the hex SHA-1 of the image data.

    Args:
      include_clobbered_blocks: If False, hash only the care_map blocks minus
          clobbered_blocks; if True, hash all of the (possibly padded) data.
    """
    if include_clobbered_blocks:
      return sha1(self.data).hexdigest()
    # ReadRangeSet() returns a list of strings, which sha1() cannot hash
    # directly; feed the pieces in one by one.
    ctx = sha1()
    for piece in self.ReadRangeSet(
        self.care_map.subtract(self.clobbered_blocks)):
      ctx.update(piece)
    return ctx.hexdigest()


class Transfer(object):
  """One command of the transfer list: produce tgt_ranges from src_ranges.

  'style' is one of "zero", "new", "diff", "move", "bsdiff" or "imgdiff".
  Constructing a Transfer appends it to 'by_id' and sets its id to its
  index in that list.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # True only if both range sets report themselves as monotonic; this is
    # later required before imgdiff may be used for the transfer.
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    # Lists of (stash id, RangeSet): blocks this transfer must stash before
    # it runs, and stashed blocks it reads.
    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Blocks stashed by this transfer minus blocks it consumes from stash."""
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    """Turns this transfer into a "new" one that reads nothing from source."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    extended: a RangeSet of additional blocks. For the target image,
#      the generated transfer list zeroes these blocks out and excludes
#      them from the final erase.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Generates a block-based update from a source image to a target image.

  Compute(prefix) writes three files:
    <prefix>.transfer.list  the ordered commands (see WriteTransfers)
    <prefix>.new.dat        target data for "new" transfers
    <prefix>.patch.dat      the concatenated bsdiff/imgdiff patches

  Args:
    tgt: The target image.
    src: The source image, or None to produce transfers that never read
        from the source.
    threads: Number of patch-computing threads; defaults to half the CPU
        count (at least 1).
    version: Transfer list format version: 1, 2 or 3.
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Runs the whole pipeline and writes the <prefix>.* output files."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Returns the hex SHA-1 of the data in 'ranges' of image 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Writes the ordered commands to <prefix>.transfer.list.

    The file starts with the format version and the total number of blocks
    written; for version >= 2 these are followed by the number of stash slots
    and the maximum number of stashed blocks. ComputePatches() must run
    first, because diff commands use each transfer's patch_start/patch_len.

    Raises:
      AssertionError: (version >= 2 with a known cache size) if the stash
          would exceed cache_size * stash_threshold bytes.
    """
    out = []

    total = 0

    # Version 2: stash index -> stash slot id. Version 3 additionally keeps
    # content hash -> number of pending users of that stash.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        # Reuse the lowest freed slot id before allocating a new one.
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stashed ranges relative to xf's own source ranges.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that were already zero in the source need no rewrite.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        # Using exactly the allowed size is fine; this matches the check in
        # ReviseStashSize(), which only rejects usage above the limit.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
                  max_stashed_blocks, max_stashed_size, max_allowed,
                  max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d (%d bytes), limit: <unknown>\n" % (
            max_stashed_blocks, max_stashed_size))

  def ReviseStashSize(self):
    """Converts transfers to "new" until the stash fits the cache limit.

    Walks the transfers in order, simulating the stash usage. A transfer
    whose stash (explicit, or implicit for an overlapping v3 diff/move)
    would exceed cache_size * stash_threshold is replaced by a "new"
    transfer, and the stashes it would have consumed are dropped.
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d %9s %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d %9s %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Every target block of the replaced command is now carried as new
        # data; count those (not the dropped stash sizes, which are zero for
        # an implicit-stash replacement).
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    print(" Total %d blocks are packed as new blocks due to insufficient "
          "cache size." % (new_blocks,))

  def ComputePatches(self, prefix):
    """Writes <prefix>.new.dat and <prefix>.patch.dat.

    Each "diff" transfer becomes "move" if its source and target data are
    identical, and "imgdiff" or "bsdiff" otherwise. Patches are computed on
    self.threads threads and each transfer's patch_start/patch_len are set.

    Raises:
      ValueError: if computing any patch fails.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort by target size only (workers pop the largest first). Comparing
      # whole tuples would compare the raw data lists, and then Transfer
      # objects, whenever two sizes tie. Output order is unaffected because
      # patches are stored by patch number.
      diff_q.sort(key=lambda item: item[0])

      patches = [None] * patch_num
      # (Transfer, exception) for patches that could not be computed.
      failures = []

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            # Stop taking new work once the queue is empty or any worker
            # has failed (the run is going to be aborted anyway).
            if not diff_q or failures:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          try:
            patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          except Exception as e:  # pylint: disable=broad-except
            # An exception escaping a worker thread would only be printed,
            # leaving a None hole in 'patches'; report it to the main thread.
            with lock:
              failures.append((xf, e))
            return
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()

      if failures:
        failed_xf, error = failures[0]
        raise ValueError("failed to compute %s patch for %s: %r" % (
            failed_xf.style, failed_xf.tgt_name, error))
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Asserts that the current transfer order is executable."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from the stash were saved before being overwritten.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Re-sorts the (now acyclic) transfers to reduce stashed data."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.  The previous order breaks ties, keeping the heap
    # comparisons away from the Transfer objects themselves.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """(Version 1) Drops source blocks that violate the chosen order."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """(Version >= 2) Stashes blocks instead of dropping them.

    For every dependency the chosen order violates, the writer stashes the
    overlapping blocks before running and the reader takes them from the
    stash; the edge is then reversed.
    """
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Orders self.transfers to (heuristically) minimize broken edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Adds weighted ordering edges between transfers that share blocks."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
      """Wrapper function for adding a Transfer().

      For BBOTA v3, we need to stash source blocks for resumable feature.
      However, with the growth of file size and the shrink of the cache
      partition source blocks are too large to be stashed. If a file occupies
      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
      into smaller pieces by getting multiple Transfer()s.

      The downside is that after splitting, we can no longer use imgdiff but
      only bsdiff."""

      MAX_BLOCKS_PER_DIFF_TRANSFER = 1024

      # We care about diff transfers only.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      # Change nothing for small files.
      if (tgt_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER and
          src_ranges.size() <= MAX_BLOCKS_PER_DIFF_TRANSFER):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        return

      pieces = 0
      while (tgt_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER and
             src_ranges.size() > MAX_BLOCKS_PER_DIFF_TRANSFER):
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        tgt_first = tgt_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        src_first = src_ranges.first(MAX_BLOCKS_PER_DIFF_TRANSFER)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
        pieces += 1

      # Handle remaining blocks.
      if tgt_ranges.size() or src_ranges.size():
        # Must be both non-empty.
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)

    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                    "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                    "diff", self.transfers, self.version >= 3)
        continue

      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Indexes source files by basename and by digit-normalized basename."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total