# blockimgdiff.py revision 5fcaaeffc3d132f1223adf7ee8230b6815477af5
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import re 24import subprocess 25import threading 26import tempfile 27 28from rangelib import RangeSet 29 30 31__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 32 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72 73class Image(object): 74 def ReadRangeSet(self, ranges): 75 raise NotImplementedError 76 77 def 
TotalSha1(self, include_clobbered_blocks=False): 78 raise NotImplementedError 79 80 81class EmptyImage(Image): 82 """A zero-length image.""" 83 blocksize = 4096 84 care_map = RangeSet() 85 clobbered_blocks = RangeSet() 86 total_blocks = 0 87 file_map = {} 88 def ReadRangeSet(self, ranges): 89 return () 90 def TotalSha1(self, include_clobbered_blocks=False): 91 # EmptyImage always carries empty clobbered_blocks, so 92 # include_clobbered_blocks can be ignored. 93 assert self.clobbered_blocks.size() == 0 94 return sha1().hexdigest() 95 96 97class DataImage(Image): 98 """An image wrapped around a single string of data.""" 99 100 def __init__(self, data, trim=False, pad=False): 101 self.data = data 102 self.blocksize = 4096 103 104 assert not (trim and pad) 105 106 partial = len(self.data) % self.blocksize 107 if partial > 0: 108 if trim: 109 self.data = self.data[:-partial] 110 elif pad: 111 self.data += '\0' * (self.blocksize - partial) 112 else: 113 raise ValueError(("data for DataImage must be multiple of %d bytes " 114 "unless trim or pad is specified") % 115 (self.blocksize,)) 116 117 assert len(self.data) % self.blocksize == 0 118 119 self.total_blocks = len(self.data) / self.blocksize 120 self.care_map = RangeSet(data=(0, self.total_blocks)) 121 self.clobbered_blocks = RangeSet() 122 123 zero_blocks = [] 124 nonzero_blocks = [] 125 reference = '\0' * self.blocksize 126 127 for i in range(self.total_blocks): 128 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 129 if d == reference: 130 zero_blocks.append(i) 131 zero_blocks.append(i+1) 132 else: 133 nonzero_blocks.append(i) 134 nonzero_blocks.append(i+1) 135 136 self.file_map = {"__ZERO": RangeSet(zero_blocks), 137 "__NONZERO": RangeSet(nonzero_blocks)} 138 139 def ReadRangeSet(self, ranges): 140 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 141 142 def TotalSha1(self, include_clobbered_blocks=False): 143 # DataImage always carries empty clobbered_blocks, so 144 # 
include_clobbered_blocks can be ignored. 145 assert self.clobbered_blocks.size() == 0 146 return sha1(self.data).hexdigest() 147 148 149class Transfer(object): 150 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 151 self.tgt_name = tgt_name 152 self.src_name = src_name 153 self.tgt_ranges = tgt_ranges 154 self.src_ranges = src_ranges 155 self.style = style 156 self.intact = (getattr(tgt_ranges, "monotonic", False) and 157 getattr(src_ranges, "monotonic", False)) 158 159 # We use OrderedDict rather than dict so that the output is repeatable; 160 # otherwise it would depend on the hash values of the Transfer objects. 161 self.goes_before = OrderedDict() 162 self.goes_after = OrderedDict() 163 164 self.stash_before = [] 165 self.use_stash = [] 166 167 self.id = len(by_id) 168 by_id.append(self) 169 170 def NetStashChange(self): 171 return (sum(sr.size() for (_, sr) in self.stash_before) - 172 sum(sr.size() for (_, sr) in self.use_stash)) 173 174 def __str__(self): 175 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 176 " to " + str(self.tgt_ranges) + ">") 177 178 179# BlockImageDiff works on two image objects. An image object is 180# anything that provides the following attributes: 181# 182# blocksize: the size in bytes of a block, currently must be 4096. 183# 184# total_blocks: the total size of the partition/image, in blocks. 185# 186# care_map: a RangeSet containing which blocks (in the range [0, 187# total_blocks) we actually care about; i.e. which blocks contain 188# data. 189# 190# file_map: a dict that partitions the blocks contained in care_map 191# into smaller domains that are useful for doing diffs on. 192# (Typically a domain is a file, and the key in file_map is the 193# pathname.) 194# 195# clobbered_blocks: a RangeSet containing which blocks contain data 196# but may be altered by the FS. They need to be excluded when 197# verifying the partition integrity. 
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes a block-level diff between two images and emits the
  transfer list plus patch/new-data files consumed by the updater."""

  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write <prefix>.* output files."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the SHA-1 hex digest of the given ranges of an image."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Serialize self.transfers into <prefix>.transfer.list."""
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 addresses stashes by content hash; keep a refcount
          # so a shared hash is freed only after its last use.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and up: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Turn "diff" transfers into moves or bsdiff/imgdiff patches and
    write <prefix>.new.dat and <prefix>.patch.dat."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for every domain in the target file map."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source filenames by basename and by digit-collapsed
    basename for the fuzzy matching done in FindTransfers."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total