# blockimgdiff.py revision 5ece99d64efe82fc1ca97a96f079ce69ea588a24
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import re 24import subprocess 25import threading 26import tempfile 27 28from rangelib import RangeSet 29 30 31__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 32 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72 73class Image(object): 74 def ReadRangeSet(self, ranges): 75 raise NotImplementedError 76 77 def 
TotalSha1(self): 78 raise NotImplementedError 79 80 81class EmptyImage(Image): 82 """A zero-length image.""" 83 blocksize = 4096 84 care_map = RangeSet() 85 clobbered_blocks = RangeSet() 86 total_blocks = 0 87 file_map = {} 88 def ReadRangeSet(self, ranges): 89 return () 90 def TotalSha1(self): 91 return sha1().hexdigest() 92 93 94class DataImage(Image): 95 """An image wrapped around a single string of data.""" 96 97 def __init__(self, data, trim=False, pad=False): 98 self.data = data 99 self.blocksize = 4096 100 101 assert not (trim and pad) 102 103 partial = len(self.data) % self.blocksize 104 if partial > 0: 105 if trim: 106 self.data = self.data[:-partial] 107 elif pad: 108 self.data += '\0' * (self.blocksize - partial) 109 else: 110 raise ValueError(("data for DataImage must be multiple of %d bytes " 111 "unless trim or pad is specified") % 112 (self.blocksize,)) 113 114 assert len(self.data) % self.blocksize == 0 115 116 self.total_blocks = len(self.data) / self.blocksize 117 self.care_map = RangeSet(data=(0, self.total_blocks)) 118 self.clobbered_blocks = RangeSet() 119 120 zero_blocks = [] 121 nonzero_blocks = [] 122 reference = '\0' * self.blocksize 123 124 for i in range(self.total_blocks): 125 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 126 if d == reference: 127 zero_blocks.append(i) 128 zero_blocks.append(i+1) 129 else: 130 nonzero_blocks.append(i) 131 nonzero_blocks.append(i+1) 132 133 self.file_map = {"__ZERO": RangeSet(zero_blocks), 134 "__NONZERO": RangeSet(nonzero_blocks)} 135 136 def ReadRangeSet(self, ranges): 137 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 138 139 def TotalSha1(self): 140 # DataImage always carries empty clobbered_blocks. 
141 assert self.clobbered_blocks.size() == 0 142 return sha1(self.data).hexdigest() 143 144 145class Transfer(object): 146 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 147 self.tgt_name = tgt_name 148 self.src_name = src_name 149 self.tgt_ranges = tgt_ranges 150 self.src_ranges = src_ranges 151 self.style = style 152 self.intact = (getattr(tgt_ranges, "monotonic", False) and 153 getattr(src_ranges, "monotonic", False)) 154 155 # We use OrderedDict rather than dict so that the output is repeatable; 156 # otherwise it would depend on the hash values of the Transfer objects. 157 self.goes_before = OrderedDict() 158 self.goes_after = OrderedDict() 159 160 self.stash_before = [] 161 self.use_stash = [] 162 163 self.id = len(by_id) 164 by_id.append(self) 165 166 def NetStashChange(self): 167 return (sum(sr.size() for (_, sr) in self.stash_before) - 168 sum(sr.size() for (_, sr) in self.use_stash)) 169 170 def __str__(self): 171 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 172 " to " + str(self.tgt_ranges) + ">") 173 174 175# BlockImageDiff works on two image objects. An image object is 176# anything that provides the following attributes: 177# 178# blocksize: the size in bytes of a block, currently must be 4096. 179# 180# total_blocks: the total size of the partition/image, in blocks. 181# 182# care_map: a RangeSet containing which blocks (in the range [0, 183# total_blocks) we actually care about; i.e. which blocks contain 184# data. 185# 186# file_map: a dict that partitions the blocks contained in care_map 187# into smaller domains that are useful for doing diffs on. 188# (Typically a domain is a file, and the key in file_map is the 189# pathname.) 190# 191# clobbered_blocks: a RangeSet containing which blocks contain data 192# but may be altered by the FS. They need to be excluded when 193# verifying the partition integrity. 
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    """Prepare a diff from 'src' to 'tgt' (both Image objects).

    threads defaults to half the CPU count (minimum 1); version selects
    the transfer-list format (1, 2 or 3).  If src is None an EmptyImage
    is substituted, so the output never reads the original image.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write <prefix>.transfer.list,
    <prefix>.new.dat and <prefix>.patch.dat."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data in 'ranges' of 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit the ordered transfer list to <prefix>.transfer.list.

    Also simulates the stash usage of the sequence to compute the
    header fields (stash slot count, max stashed blocks) and to assert
    the 512 MB stash ceiling.
    """
    out = []

    total = 0
    performs_read = False

    # In version 2, 'stashes' maps stash-slot id -> numeric sid.  In
    # version 3 it additionally maps block-hash -> reference count, so
    # duplicate stashes of identical data are emitted only once.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # map the stashed range into the coordinate space of this
          # transfer's source blocks.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              # last reference to this stashed data; free it on device.
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size();
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size();
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # blocks that are already zero in the source need not be written.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

    # sanity check: abort if we're going to need more than 512 MB of
    # stash space
    assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # versions 2 and up: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Write all "new" data to <prefix>.new.dat and compute the
    bsdiff/imgdiff patches (in parallel) into <prefix>.patch.dat,
    recording each transfer's patch offset/length on the Transfer."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # sort by target size so the largest (slowest) diffs are popped
      # first by the workers.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # blocks read from the stash were written (and stashed) earlier,
        # so they are exempt from the "read before write" check.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version 1 fixup: make the chosen sequence valid by trimming, from
    each transfer's source, any blocks written by transfers that were
    scheduled before it despite a goes_before dependency."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version 2+ fixup: instead of trimming source blocks, stash the
    overlapping blocks and reverse the violated edge so the sequence
    becomes valid without losing diff input."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      # iterate over a copy; the loop mutates xf.goes_before.
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Populate goes_before/goes_after edges between all transfer pairs
    whose source and target block ranges intersect."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per target file-map domain, matching each to
    a source domain by path, basename, or digit-collapsed basename."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No source match at all: send the target data as new.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and digit-collapsed basename for
    the fallback matching in FindTransfers."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total