# blockimgdiff.py, revision 3203bc10f0bc6a1acce8afeee6d8df1f0f64f72c
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import re 24import subprocess 25import threading 26import tempfile 27 28from rangelib import RangeSet 29 30 31__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 32 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72 73class Image(object): 74 def ReadRangeSet(self, ranges): 75 raise NotImplementedError 76 77 def 
TotalSha1(self, include_clobbered_blocks=False): 78 raise NotImplementedError 79 80 81class EmptyImage(Image): 82 """A zero-length image.""" 83 blocksize = 4096 84 care_map = RangeSet() 85 clobbered_blocks = RangeSet() 86 extended = RangeSet() 87 total_blocks = 0 88 file_map = {} 89 def ReadRangeSet(self, ranges): 90 return () 91 def TotalSha1(self, include_clobbered_blocks=False): 92 # EmptyImage always carries empty clobbered_blocks, so 93 # include_clobbered_blocks can be ignored. 94 assert self.clobbered_blocks.size() == 0 95 return sha1().hexdigest() 96 97 98class DataImage(Image): 99 """An image wrapped around a single string of data.""" 100 101 def __init__(self, data, trim=False, pad=False): 102 self.data = data 103 self.blocksize = 4096 104 105 assert not (trim and pad) 106 107 partial = len(self.data) % self.blocksize 108 padded = False 109 if partial > 0: 110 if trim: 111 self.data = self.data[:-partial] 112 elif pad: 113 self.data += '\0' * (self.blocksize - partial) 114 padded = True 115 else: 116 raise ValueError(("data for DataImage must be multiple of %d bytes " 117 "unless trim or pad is specified") % 118 (self.blocksize,)) 119 120 assert len(self.data) % self.blocksize == 0 121 122 self.total_blocks = len(self.data) / self.blocksize 123 self.care_map = RangeSet(data=(0, self.total_blocks)) 124 # When the last block is padded, we always write the whole block even for 125 # incremental OTAs. Because otherwise the last block may get skipped if 126 # unchanged for an incremental, but would fail the post-install 127 # verification if it has non-zero contents in the padding bytes. 
128 # Bug: 23828506 129 if padded: 130 self.clobbered_blocks = RangeSet( 131 data=(self.total_blocks-1, self.total_blocks)) 132 else: 133 self.clobbered_blocks = RangeSet() 134 self.extended = RangeSet() 135 136 zero_blocks = [] 137 nonzero_blocks = [] 138 reference = '\0' * self.blocksize 139 140 for i in range(self.total_blocks-1 if padded else self.total_blocks): 141 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 142 if d == reference: 143 zero_blocks.append(i) 144 zero_blocks.append(i+1) 145 else: 146 nonzero_blocks.append(i) 147 nonzero_blocks.append(i+1) 148 149 self.file_map = {"__ZERO": RangeSet(zero_blocks), 150 "__NONZERO": RangeSet(nonzero_blocks)} 151 152 if self.clobbered_blocks: 153 self.file_map["__COPY"] = self.clobbered_blocks 154 155 def ReadRangeSet(self, ranges): 156 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 157 158 def TotalSha1(self, include_clobbered_blocks=False): 159 if not include_clobbered_blocks: 160 ranges = self.care_map.subtract(self.clobbered_blocks) 161 return sha1(self.ReadRangeSet(ranges)).hexdigest() 162 else: 163 return sha1(self.data).hexdigest() 164 165 166class Transfer(object): 167 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 168 self.tgt_name = tgt_name 169 self.src_name = src_name 170 self.tgt_ranges = tgt_ranges 171 self.src_ranges = src_ranges 172 self.style = style 173 self.intact = (getattr(tgt_ranges, "monotonic", False) and 174 getattr(src_ranges, "monotonic", False)) 175 176 # We use OrderedDict rather than dict so that the output is repeatable; 177 # otherwise it would depend on the hash values of the Transfer objects. 
178 self.goes_before = OrderedDict() 179 self.goes_after = OrderedDict() 180 181 self.stash_before = [] 182 self.use_stash = [] 183 184 self.id = len(by_id) 185 by_id.append(self) 186 187 def NetStashChange(self): 188 return (sum(sr.size() for (_, sr) in self.stash_before) - 189 sum(sr.size() for (_, sr) in self.use_stash)) 190 191 def __str__(self): 192 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 193 " to " + str(self.tgt_ranges) + ">") 194 195 196# BlockImageDiff works on two image objects. An image object is 197# anything that provides the following attributes: 198# 199# blocksize: the size in bytes of a block, currently must be 4096. 200# 201# total_blocks: the total size of the partition/image, in blocks. 202# 203# care_map: a RangeSet containing which blocks (in the range [0, 204# total_blocks) we actually care about; i.e. which blocks contain 205# data. 206# 207# file_map: a dict that partitions the blocks contained in care_map 208# into smaller domains that are useful for doing diffs on. 209# (Typically a domain is a file, and the key in file_map is the 210# pathname.) 211# 212# clobbered_blocks: a RangeSet containing which blocks contain data 213# but may be altered by the FS. They need to be excluded when 214# verifying the partition integrity. 215# 216# ReadRangeSet(): a function that takes a RangeSet and returns the 217# data contained in the image blocks of that RangeSet. The data 218# is returned as a list or tuple of strings; concatenating the 219# elements together should produce the requested data. 220# Implementations are free to break up the data into list/tuple 221# elements in any way that is convenient. 222# 223# TotalSha1(): a function that returns (as a hex string) the SHA-1 224# hash of all the data in the image (ie, all the blocks in the 225# care_map minus clobbered_blocks, or including the clobbered 226# blocks if include_clobbered_blocks is True). 
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes the transfer list turning the src image into the tgt image.

  Compute(prefix) is the entry point; it writes prefix+".transfer.list",
  prefix+".new.dat" and prefix+".patch.dat".
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    # Default to half the CPUs (minimum 1) for patch computation.
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write the output files under 'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the SHA-1 hex digest of the given ranges read from 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit the transfer list file (prefix + ".transfer.list").

    Walks self.transfers in order, emitting stash/free bookkeeping and one
    command per transfer, while tracking the maximum number of
    simultaneously stashed blocks.
    """
    out = []

    total = 0
    # NOTE(review): performs_read is set but never consulted in this
    # revision; kept for compatibility.
    performs_read = False

    # Version 2 identifies stashes by numeric slot id; version 3 keys them
    # by content hash (with a refcount, so identical ranges share a slot).
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the stash commands that must precede this transfer.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3: dedupe stashes by content hash; only emit the
          # stash command the first time a given hash is seen.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        # Build the <src_str> operand, one of:
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Map the stashed range into the coordinate space of the
          # concatenated source data.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            # Drop the refcount; free the slot once the last user is done.
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source and target ranges are identical is a no-op;
        # emit nothing and count nothing.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks already zero in the source need not be rewritten.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Compute patch data for all "diff" transfers.

    Writes new-style data to prefix+".new.dat" and concatenated patches to
    prefix+".patch.dat"; classifies each "diff" transfer as "move",
    "bsdiff" or "imgdiff" and records patch_start/patch_len on it.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so the largest targets are processed first (popped last).
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?  (The GIL is
      # released while the external diff tool runs, so threads do help.)
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into one file, recording each transfer's
    # offset and length within it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Verify the final transfer order is executable and complete."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from the stash were captured before being
        # overwritten, so they don't count as reads of current contents.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Topologically re-sort transfers to minimize peak stash usage."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version 1 fixup: drop source blocks that violate the ordering."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version 2+ fixup: stash blocks to reverse ordering violations."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Heuristically order transfers to minimize violated dependencies."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    # O(n^2) pairwise comparison over all transfers.
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per target domain, matching sources by name."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the target data is sent as new data.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and digit-collapsed basename."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total