blockimgdiff.py revision 29f529f33e5288efd8823cdc935260cfb0d8c7fa
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import re 24import subprocess 25import threading 26import tempfile 27 28from rangelib import RangeSet 29 30 31__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 32 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72 73class Image(object): 74 def ReadRangeSet(self, ranges): 75 raise NotImplementedError 76 77 def 
TotalSha1(self): 78 raise NotImplementedError 79 80 81class EmptyImage(Image): 82 """A zero-length image.""" 83 blocksize = 4096 84 care_map = RangeSet() 85 total_blocks = 0 86 file_map = {} 87 def ReadRangeSet(self, ranges): 88 return () 89 def TotalSha1(self): 90 return sha1().hexdigest() 91 92 93class DataImage(Image): 94 """An image wrapped around a single string of data.""" 95 96 def __init__(self, data, trim=False, pad=False): 97 self.data = data 98 self.blocksize = 4096 99 100 assert not (trim and pad) 101 102 partial = len(self.data) % self.blocksize 103 if partial > 0: 104 if trim: 105 self.data = self.data[:-partial] 106 elif pad: 107 self.data += '\0' * (self.blocksize - partial) 108 else: 109 raise ValueError(("data for DataImage must be multiple of %d bytes " 110 "unless trim or pad is specified") % 111 (self.blocksize,)) 112 113 assert len(self.data) % self.blocksize == 0 114 115 self.total_blocks = len(self.data) / self.blocksize 116 self.care_map = RangeSet(data=(0, self.total_blocks)) 117 118 zero_blocks = [] 119 nonzero_blocks = [] 120 reference = '\0' * self.blocksize 121 122 for i in range(self.total_blocks): 123 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 124 if d == reference: 125 zero_blocks.append(i) 126 zero_blocks.append(i+1) 127 else: 128 nonzero_blocks.append(i) 129 nonzero_blocks.append(i+1) 130 131 self.file_map = {"__ZERO": RangeSet(zero_blocks), 132 "__NONZERO": RangeSet(nonzero_blocks)} 133 134 def ReadRangeSet(self, ranges): 135 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 136 137 def TotalSha1(self): 138 return sha1(self.data).hexdigest() 139 140 141class Transfer(object): 142 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 143 self.tgt_name = tgt_name 144 self.src_name = src_name 145 self.tgt_ranges = tgt_ranges 146 self.src_ranges = src_ranges 147 self.style = style 148 self.intact = (getattr(tgt_ranges, "monotonic", False) and 149 getattr(src_ranges, 
"monotonic", False)) 150 151 # We use OrderedDict rather than dict so that the output is repeatable; 152 # otherwise it would depend on the hash values of the Transfer objects. 153 self.goes_before = OrderedDict() 154 self.goes_after = OrderedDict() 155 156 self.stash_before = [] 157 self.use_stash = [] 158 159 self.id = len(by_id) 160 by_id.append(self) 161 162 def NetStashChange(self): 163 return (sum(sr.size() for (_, sr) in self.stash_before) - 164 sum(sr.size() for (_, sr) in self.use_stash)) 165 166 def __str__(self): 167 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 168 " to " + str(self.tgt_ranges) + ">") 169 170 171# BlockImageDiff works on two image objects. An image object is 172# anything that provides the following attributes: 173# 174# blocksize: the size in bytes of a block, currently must be 4096. 175# 176# total_blocks: the total size of the partition/image, in blocks. 177# 178# care_map: a RangeSet containing which blocks (in the range [0, 179# total_blocks) we actually care about; i.e. which blocks contain 180# data. 181# 182# file_map: a dict that partitions the blocks contained in care_map 183# into smaller domains that are useful for doing diffs on. 184# (Typically a domain is a file, and the key in file_map is the 185# pathname.) 186# 187# ReadRangeSet(): a function that takes a RangeSet and returns the 188# data contained in the image blocks of that RangeSet. The data 189# is returned as a list or tuple of strings; concatenating the 190# elements together should produce the requested data. 191# Implementations are free to break up the data into list/tuple 192# elements in any way that is convenient. 193# 194# TotalSha1(): a function that returns (as a hex string) the SHA-1 195# hash of all the data in the image (ie, all the blocks in the 196# care_map) 197# 198# When creating a BlockImageDiff, the src image may be None, in which 199# case the list of transfers produced will never read from the 200# original image. 
class BlockImageDiff(object):
  """Compute the set of transfers that turn the 'src' image into the
  'tgt' image, and write them out as a transfer list plus new-data and
  patch-data files (see Compute())."""

  def __init__(self, tgt, src=None, threads=None, version=3):
    # Default to half the available CPUs (minimum 1) for patch workers.
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full pipeline and write the output files named
    <prefix>.transfer.list, <prefix>.new.dat and <prefix>.patch.dat."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data in 'ranges' of image 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit the transfer list for self.transfers to
    <prefix>.transfer.list, tracking stash usage as it goes."""
    out = []

    total = 0
    performs_read = False

    # version 2: maps stash id -> slot number.
    # version 3: maps stash hash -> refcount of pending uses.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        # Reuse a freed slot if one is available, otherwise allocate a
        # new one.
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 keys stashes by content hash so identical data is
          # only stashed once; bump the refcount on a duplicate.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stashed range in coordinates relative to the
          # source ranges of this transfer.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            # Free the stash once its last pending use is consumed.
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op on the device,
        # so emit nothing for it.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size();
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size();
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that are already zero in the source need not be written.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))


    # sanity check: abort if we're going to need more than 512 MB of
    # stash space
    assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Turn "diff" transfers into "move"/"bsdiff"/"imgdiff"; write raw
    new data to <prefix>.new.dat and patches to <prefix>.patch.dat,
    recording each transfer's patch offset/length."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Workers pop work items under the lock; the patch computation
        # itself runs unlocked.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert that the transfer sequence never reads a block after
    writing it and writes every cared-about block exactly once."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from a stash were saved before being overwritten,
        # so they don't count as reads of the live source.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Reorder the (now acyclic) transfer graph topologically, greedily
    preferring vertices that minimize stashed data."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """(Version 1.)  Resolve ordering violations by dropping the
    conflicting blocks from each transfer's source, losing diff input
    rather than stashing."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """(Versions 2+.)  Resolve ordering violations by stashing the
    conflicting blocks and reversing the offending edges."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Order the transfers to heuristically minimize the total weight
    of violated (left-pointing) dependency edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Populate goes_before/goes_after edges between every pair of
    transfers where one reads blocks the other writes."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for every domain in the target file map,
    matching it with a source domain per the rules in Compute()."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Build the basename and number-pattern lookup tables used by
    FindTransfers() to match target files to source files."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total