blockimgdiff.py revision d47d8e14880132c42a75f41c8041851797c75e35
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import common 20import heapq 21import itertools 22import multiprocessing 23import os 24import re 25import subprocess 26import threading 27import tempfile 28 29from rangelib import RangeSet 30 31 32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 33 34 35def compute_patch(src, tgt, imgdiff=False): 36 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 37 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 38 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 39 os.close(patchfd) 40 41 try: 42 with os.fdopen(srcfd, "wb") as f_src: 43 for p in src: 44 f_src.write(p) 45 46 with os.fdopen(tgtfd, "wb") as f_tgt: 47 for p in tgt: 48 f_tgt.write(p) 49 try: 50 os.unlink(patchfile) 51 except OSError: 52 pass 53 if imgdiff: 54 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 55 stdout=open("/dev/null", "a"), 56 stderr=subprocess.STDOUT) 57 else: 58 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 59 60 if p: 61 raise ValueError("diff failed: " + str(p)) 62 63 with open(patchfile, "rb") as f: 64 return f.read() 65 finally: 66 try: 67 os.unlink(srcfile) 68 os.unlink(tgtfile) 69 os.unlink(patchfile) 70 except OSError: 71 pass 72 73 74class Image(object): 75 def ReadRangeSet(self, ranges): 76 raise NotImplementedError 77 78 
def TotalSha1(self, include_clobbered_blocks=False): 79 raise NotImplementedError 80 81 82class EmptyImage(Image): 83 """A zero-length image.""" 84 blocksize = 4096 85 care_map = RangeSet() 86 clobbered_blocks = RangeSet() 87 extended = RangeSet() 88 total_blocks = 0 89 file_map = {} 90 def ReadRangeSet(self, ranges): 91 return () 92 def TotalSha1(self, include_clobbered_blocks=False): 93 # EmptyImage always carries empty clobbered_blocks, so 94 # include_clobbered_blocks can be ignored. 95 assert self.clobbered_blocks.size() == 0 96 return sha1().hexdigest() 97 98 99class DataImage(Image): 100 """An image wrapped around a single string of data.""" 101 102 def __init__(self, data, trim=False, pad=False): 103 self.data = data 104 self.blocksize = 4096 105 106 assert not (trim and pad) 107 108 partial = len(self.data) % self.blocksize 109 if partial > 0: 110 if trim: 111 self.data = self.data[:-partial] 112 elif pad: 113 self.data += '\0' * (self.blocksize - partial) 114 else: 115 raise ValueError(("data for DataImage must be multiple of %d bytes " 116 "unless trim or pad is specified") % 117 (self.blocksize,)) 118 119 assert len(self.data) % self.blocksize == 0 120 121 self.total_blocks = len(self.data) / self.blocksize 122 self.care_map = RangeSet(data=(0, self.total_blocks)) 123 self.clobbered_blocks = RangeSet() 124 self.extended = RangeSet() 125 126 zero_blocks = [] 127 nonzero_blocks = [] 128 reference = '\0' * self.blocksize 129 130 for i in range(self.total_blocks): 131 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 132 if d == reference: 133 zero_blocks.append(i) 134 zero_blocks.append(i+1) 135 else: 136 nonzero_blocks.append(i) 137 nonzero_blocks.append(i+1) 138 139 self.file_map = {"__ZERO": RangeSet(zero_blocks), 140 "__NONZERO": RangeSet(nonzero_blocks)} 141 142 def ReadRangeSet(self, ranges): 143 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 144 145 def TotalSha1(self, include_clobbered_blocks=False): 146 # DataImage 
always carries empty clobbered_blocks, so 147 # include_clobbered_blocks can be ignored. 148 assert self.clobbered_blocks.size() == 0 149 return sha1(self.data).hexdigest() 150 151 152class Transfer(object): 153 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 154 self.tgt_name = tgt_name 155 self.src_name = src_name 156 self.tgt_ranges = tgt_ranges 157 self.src_ranges = src_ranges 158 self.style = style 159 self.intact = (getattr(tgt_ranges, "monotonic", False) and 160 getattr(src_ranges, "monotonic", False)) 161 162 # We use OrderedDict rather than dict so that the output is repeatable; 163 # otherwise it would depend on the hash values of the Transfer objects. 164 self.goes_before = OrderedDict() 165 self.goes_after = OrderedDict() 166 167 self.stash_before = [] 168 self.use_stash = [] 169 170 self.id = len(by_id) 171 by_id.append(self) 172 173 def NetStashChange(self): 174 return (sum(sr.size() for (_, sr) in self.stash_before) - 175 sum(sr.size() for (_, sr) in self.use_stash)) 176 177 def __str__(self): 178 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 179 " to " + str(self.tgt_ranges) + ">") 180 181 182# BlockImageDiff works on two image objects. An image object is 183# anything that provides the following attributes: 184# 185# blocksize: the size in bytes of a block, currently must be 4096. 186# 187# total_blocks: the total size of the partition/image, in blocks. 188# 189# care_map: a RangeSet containing which blocks (in the range [0, 190# total_blocks) we actually care about; i.e. which blocks contain 191# data. 192# 193# file_map: a dict that partitions the blocks contained in care_map 194# into smaller domains that are useful for doing diffs on. 195# (Typically a domain is a file, and the key in file_map is the 196# pathname.) 197# 198# clobbered_blocks: a RangeSet containing which blocks contain data 199# but may be altered by the FS. 
# They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes the set of transfers turning the 'src' image into 'tgt'.

  Compute(prefix) drives the whole process and writes the transfer
  list and patch/new-data files under the given filename prefix.
  'version' selects the transfer-list format (1, 2 or 3).
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    # Default to half the CPUs (at least one) for the patch workers.
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Compute the transfer sequence and write all output files.

    Output files are '<prefix>.transfer.list', '<prefix>.new.dat' and
    '<prefix>.patch.dat'.
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data in 'ranges' read from 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit the transfer list for self.transfers to '<prefix>.transfer.list'.

    Also tracks stash usage (versions >= 2) and asserts that the peak
    stash size stays under the configured cache limit.
    """
    out = []

    total = 0
    # NOTE(review): performs_read is assigned below but never consumed
    # in this method.
    performs_read = False

    # Maps stash id (v2) or stash hash (v3, with a refcount) to bookkeeping.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 keys stashes by content hash; identical contents
          # share one slot via a refcount.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Map the stashed range into the coordinates of the source
          # string being assembled.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              # Last reference: tell the device to free the stash.
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op; emit nothing.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      if self.version >= 2:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      max_allowed = common.OPTIONS.cache_size * common.OPTIONS.stash_threshold
      print("max stashed blocks: %d (%d bytes), limit: %d bytes (%.2f%%)\n" % (
          max_stashed_blocks, max_stashed_size, max_allowed,
          max_stashed_size * 100.0 / max_allowed))

  def ComputePatches(self, prefix):
    """Classify "diff" transfers and compute their patches.

    Writes raw new-file data to '<prefix>.new.dat' and concatenated
    patches to '<prefix>.patch.dat', recording each transfer's
    patch_start/patch_len.  Patch computation is fanned out over
    self.threads worker threads.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Workers pop the largest remaining item (diff_q is sorted by
        # target size) until the queue is empty.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert that the transfer order never reads an already-written block
    and writes every block in the target care map exactly once."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Stashed reads were captured before the writer ran, so exclude
        # them from the overlap check.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Topologically re-sort the (now acyclic) transfer graph, greedily
    picking the vertex that minimizes net stashed data at each step."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version-1 fixup: for every ordering violation, trim the conflicting
    blocks from the reader's source set (losing diff input data)."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version-2+ fixup: for every ordering violation, stash the conflicting
    blocks and reverse the edge instead of discarding source data."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      # Iterate over a copy since edges are mutated inside the loop.
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Order the transfers to heuristically minimize the total weight of
    ordering dependencies that end up violated."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between every pair of transfers
    whose target and source block ranges intersect (O(n^2) pairs)."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for every domain in the target file map, matching
    each against a source domain by path, basename, or number pattern."""
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No usable source; the target data must be shipped in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source filenames by basename and by digit-collapsed basename
    for the fallback matching done in FindTransfers()."""
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total