# blockimgdiff.py revision 575d68a48edc90d655509f2980dacc69958948de
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import common
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  """Computes a binary patch that turns 'src' into 'tgt'.

  Args:
    src: iterable of data pieces whose concatenation is the source data.
    tgt: iterable of data pieces whose concatenation is the target data.
    imgdiff: if True run "imgdiff -z", otherwise run "bsdiff".

  Returns:
    The raw contents of the patch file produced by the diff tool.

  Raises:
    ValueError: if the diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    # Wrap both descriptors in one 'with' so that a failure while writing
    # the source data still closes the target descriptor.
    with os.fdopen(srcfd, "wb") as f_src, os.fdopen(tgtfd, "wb") as f_tgt:
      for piece in src:
        f_src.write(piece)
      for piece in tgt:
        f_tgt.write(piece)

    # Remove the empty placeholder created by mkstemp before running the
    # diff tool, which writes its output to this path.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Discard imgdiff's stdout/stderr; the 'with' closes the handle once
      # the child has exited (the previous code leaked one fd per call).
      with open(os.devnull, "w") as devnull:
        returncode = subprocess.call(
            ["imgdiff", "-z", srcfile, tgtfile, patchfile],
            stdout=devnull, stderr=subprocess.STDOUT)
    else:
      returncode = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if returncode:
      raise ValueError("diff failed: " + str(returncode))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Remove each temp file independently so that one failed unlink does
    # not leave the remaining files behind.
    for path in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(path)
      except OSError:
        pass


class Image(object):
  """Abstract base for images consumed by BlockImageDiff."""

  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  # Class-level dict shared by every instance; callers in this file only
  # read it.
  file_map = {}

  def ReadRangeSet(self, ranges):
    """Returns no data: an empty image has no blocks."""
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Returns the SHA-1 hex digest of zero bytes of data."""
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wraps 'data' as a sequence of 4096-byte blocks.

    Args:
      data: the image contents.
      trim: drop a trailing partial block.
      pad: zero-fill a trailing partial block to a full block.

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
          neither trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        # b'' literals are plain str on Python 2 and bytes on Python 3.
        self.data += b'\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int on both Python 2 and 3.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    # Classify every block as all-zero or not; each block i contributes
    # the half-open range [i, i+1) to one of the two lists.
    zero_blocks = []
    nonzero_blocks = []
    reference = b'\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Returns one data slice per (start, end) block range in 'ranges'."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Returns the SHA-1 hex digest of all the wrapped data."""
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
  """One unit of work in the update: produce tgt_ranges of the target.

  The new instance appends itself to 'by_id' and uses its index there as
  its id.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Returns blocks stashed before this transfer minus blocks it unstashes."""
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS.
#      They need to be excluded when verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes a block-level update that turns 'src' into 'tgt'.

  Compute(prefix) writes <prefix>.transfer.list (the command list),
  <prefix>.new.dat (literal data for "new" transfers) and
  <prefix>.patch.dat (concatenated bsdiff/imgdiff patches).
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    """Validates the two images and records the diff settings.

    Args:
      tgt: target image (see the image interface described above).
      src: source image, or None to diff against an EmptyImage.
      threads: number of patch-computing threads; defaults to half the
          CPU count.  Values below 1 are raised to 1.
      version: transfer list format version; must be 1, 2 or 3.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
    # At least one worker is required; with zero workers ComputePatches
    # would never fill in any patch.
    self.threads = max(threads, 1)
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Runs the whole pipeline and writes the output files for 'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Returns the SHA-1 hex digest of 'ranges' read from 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Writes <prefix>.transfer.list for the current transfer order.

    The file starts with the format version and the total block count;
    for version >= 2 these are followed by the number of stash slots and
    the maximum number of blocks stashed at once.  Then come the per-
    transfer commands (with stash/free commands), and finally zero/erase
    commands for extended and don't-care blocks.
    """
    out = []

    total = 0

    # stash_slots maps a stash index (assigned in ReverseBackwardEdges) to
    # the numeric slot id it occupies.  stash_refs (version 3 only) counts
    # the outstanding stashes for each content hash.
    stash_slots = {}
    stash_refs = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stash_slots
        # Reuse the lowest freed slot id before allocating a new one.
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stash_slots[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          stash_refs[sh] = stash_refs.get(sh, 0) + 1
          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stash_slots.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          mapped = xf.src_ranges.map_within(sr)
          mapped_stashes.append(mapped)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, mapped.to_string_raw()))
          else:
            # Only version 3 refers to stashes by content hash.
            sh = self.HashBlocks(self.src, sr)
            assert sh in stash_refs
            src_str.append("%s:%s" % (sh, mapped.to_string_raw()))
            stash_refs[sh] -= 1
            if stash_refs[sh] == 0:
              free_string.append("free %s\n" % (sh,))
              del stash_refs[sh]
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move onto identical ranges is a no-op; emit nothing for it.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that are already zero in the source need no write.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
                  max_stashed_blocks, max_stashed_size, max_allowed,
                  max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d (%d bytes), limit: <unknown>\n" % (
            max_stashed_blocks, max_stashed_size))

  def ComputePatches(self, prefix):
    """Writes <prefix>.new.dat and <prefix>.patch.dat.

    Each "diff" transfer becomes "move" (identical data), "imgdiff" or
    "bsdiff"; patches are computed on self.threads worker threads and
    each patched transfer gets patch_start/patch_len set.

    Raises:
      Whatever compute_patch raised in a worker (e.g. ValueError when a
      diff tool fails), re-raised here after all workers have stopped.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort by (size, index) only.  Sorting the raw tuples would, on size
      # ties, compare the (possibly huge) data lists and then Transfer
      # objects.  Workers pop from the end, so the largest diffs go first.
      diff_q.sort(key=lambda item: (item[0], item[4]))

      patches = [None] * patch_num
      errors = []

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            # Stop taking new work once any worker has failed.
            if errors or not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          try:
            patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          except Exception as e:  # pylint: disable=broad-except
            # Hand the failure to the main thread instead of letting it
            # vanish with this thread.
            with lock:
              errors.append(e)
            return
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()

      if errors:
        raise errors[0]
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Asserts that the current transfer order is safe and complete."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Stashed blocks were saved earlier, so they may have been
        # overwritten by now.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Topologically re-sorts transfers, greedily minimizing stash use."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.  The unique 'order' field breaks ties, so Transfer objects
    # are never compared.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version 1: trims source blocks so every dependency is satisfied."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost_source += size - xf.src_ranges.size()

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version >= 2: turns violated dependencies into stash operations."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Orders self.transfers to minimize the weight of backward edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Adds an edge for every pair where one transfer reads another's output."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Creates one Transfer per target file_map domain."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Indexes source paths by basename and by digit-normalized basename."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total