# blockimgdiff.py, revision 623381880a32a2912f95949a6c406038ac4e7064

# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import tempfile
import threading

from rangelib import *

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
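

# Illustrative example (not from the original source; assumes rangelib's
# pair-cancelling RangeSet constructor and its raw string syntax
# "<count>,<start>,<end>,..."): a 3-block DataImage whose middle block
# is all zeros.  Appending each block index as an (i, i+1) pair lets
# RangeSet cancel the repeated endpoints of adjacent blocks, merging
# runs into single ranges:
#
#   img = DataImage("A" * 4096 + "\0" * 4096 + "B" * 4096)
#   img.total_blocks                           # 3
#   img.care_map.to_string_raw()               # "2,0,3"
#   img.file_map["__ZERO"].to_string_raw()     # "2,1,2"
#   img.file_map["__NONZERO"].to_string_raw()  # "4,0,1,2,3"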


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=2):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
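    # For example (illustrative, not from the original source): if
    # transfer A writes target blocks 10-19 and transfer B reads source
    # blocks 15-24, then B must run before A (or, in version 2, the
    # overlapping blocks 15-19 must be stashed before A runs);
    # GenerateDigraph records this as an edge B -> A with weight 5.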
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    out = []

    out.append("%d\n" % (self.version,))   # format version number
    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version == 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # both versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
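      #
      # A hand-constructed version 2 example (illustrative only,
      # assuming rangelib's raw syntax "<count>,<start>,<end>,..."):
      #
      #   stash 0 2,20,25           stash source blocks 20-24 in slot 0
      #   move 2,0,10 10 2,30,40    target blocks 0-9 come from source
      #                             blocks 30-39 (10 blocks, none stashed)
      #   bsdiff 0 2310 2,20,25 5 - 0:2,0,5
      #                             patch target blocks 20-24; all 5
      #                             source blocks come from stash slot 0
      #                             ("-" means nothing is read directly
      #                             from the source image)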

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    # sanity check: abort if we're going to need more than 512 MB of
    # stash space
    assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.  The erase must go right after
      # the header lines: two of them in version 1, four in version 2.
      out.insert(4 if self.version >= 2 else 2,
                 "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))
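
  # Illustrative stash-slot reuse (hand-constructed, not from the
  # original source): with transfers T1..T3 where T1 stashes 10 blocks,
  # T2 consumes them, and T3 stashes 4 blocks, WriteTransfers emits
  #
  #   stash 0 2,40,50     (T1: slot 0 allocated, stashed_blocks = 10)
  #   ...T2 references "0:..." and frees slot 0 (stashed_blocks = 0)
  #   stash 0 2,70,74     (T3 reuses slot 0 via free_stash_ids)
  #
  # so next_stash_id ends up 1 and max_stashed_blocks 10: slot ids are
  # dense, and next_stash_id is the number of simultaneous slots the
  # updater must provide.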

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently,
          # eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #   - the source and target files are monotonic (ie, the
            #     data is stored with blocks in increasing order), and
            #   - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check
    # that:
    #   - we never read a block after writing it, and
    #   - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map
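
  # Example of what AssertSequenceGood rejects (hand-constructed): if
  # transfer T1 writes target blocks 5-9 and a later transfer T2 reads
  # source blocks 8-12 directly (not via a stash), then
  # touched.overlaps(x) is true at T2 and the assertion fires --
  # on-device, T2 would be reading data that T1 already overwrote.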

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None   # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))
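
  # Illustrative reversal (hand-constructed): suppose xf reads source
  # blocks 3-7, u writes target blocks 5-9, and the chosen order puts
  # u first.  The overlap 5-7 is recorded in u.stash_before and
  # xf.use_stash under a shared stash id, and the edge is flipped so xf
  # now depends on u.  On-device, u saves blocks 5-7 before writing,
  # and xf reads them from the stash instead of the (now overwritten)
  # source blocks.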

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()   # the left side of the sequence, built from left to right
    s2 = deque()   # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved that we
      # get by pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size
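
  # A tiny worked example of the heuristic (hand-constructed): with
  # edges T1->T2 (weight 4), T2->T3 (weight 2), and T3->T1 (weight 1),
  # the graph is a single cycle, so there are no pure sources or sinks.
  # The "best" vertex maximizes sum(outgoing) - sum(incoming):
  # T1 scores 4-1=3, T2 scores 2-4=-2, T3 scores 1-2=-1, so T1 is
  # emitted first and the cheap T3->T1 edge becomes the left-pointing
  # edge that RemoveBackwardEdges/ReverseBackwardEdges must fix up.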

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped; eg, "libfoo-1.2.so" and "libfoo-1.3.so"
        # both become "libfoo-#.#.so".)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
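
# Illustrative driver (hypothetical; the real caller lives elsewhere in
# the releasetools scripts).  Given old and new raw partition images,
# something like this produces the three artifacts the updater consumes:
#
#   src = DataImage(open("system_old.img", "rb").read(), pad=True)
#   tgt = DataImage(open("system_new.img", "rb").read(), pad=True)
#   BlockImageDiff(tgt, src, version=2).Compute("system")
#   # -> system.transfer.list, system.new.dat, system.patch.dat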