# blockimgdiff.py revision dd67a295cc91ad636b81b705b5bedda945035568
1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import pprint 24import re 25import subprocess 26import sys 27import threading 28import tempfile 29 30from rangelib import * 31 32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72class EmptyImage(object): 73 """A zero-length image.""" 74 blocksize = 4096 75 care_map = 
RangeSet() 76 total_blocks = 0 77 file_map = {} 78 def ReadRangeSet(self, ranges): 79 return () 80 def TotalSha1(self): 81 return sha1().hexdigest() 82 83 84class DataImage(object): 85 """An image wrapped around a single string of data.""" 86 87 def __init__(self, data, trim=False, pad=False): 88 self.data = data 89 self.blocksize = 4096 90 91 assert not (trim and pad) 92 93 partial = len(self.data) % self.blocksize 94 if partial > 0: 95 if trim: 96 self.data = self.data[:-partial] 97 elif pad: 98 self.data += '\0' * (self.blocksize - partial) 99 else: 100 raise ValueError(("data for DataImage must be multiple of %d bytes " 101 "unless trim or pad is specified") % 102 (self.blocksize,)) 103 104 assert len(self.data) % self.blocksize == 0 105 106 self.total_blocks = len(self.data) / self.blocksize 107 self.care_map = RangeSet(data=(0, self.total_blocks)) 108 109 zero_blocks = [] 110 nonzero_blocks = [] 111 reference = '\0' * self.blocksize 112 113 for i in range(self.total_blocks): 114 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 115 if d == reference: 116 zero_blocks.append(i) 117 zero_blocks.append(i+1) 118 else: 119 nonzero_blocks.append(i) 120 nonzero_blocks.append(i+1) 121 122 self.file_map = {"__ZERO": RangeSet(zero_blocks), 123 "__NONZERO": RangeSet(nonzero_blocks)} 124 125 def ReadRangeSet(self, ranges): 126 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 127 128 def TotalSha1(self): 129 if not hasattr(self, "sha1"): 130 self.sha1 = sha1(self.data).hexdigest() 131 return self.sha1 132 133 134class Transfer(object): 135 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 136 self.tgt_name = tgt_name 137 self.src_name = src_name 138 self.tgt_ranges = tgt_ranges 139 self.src_ranges = src_ranges 140 self.style = style 141 self.intact = (getattr(tgt_ranges, "monotonic", False) and 142 getattr(src_ranges, "monotonic", False)) 143 self.goes_before = {} 144 self.goes_after = {} 145 146 self.stash_before 
= [] 147 self.use_stash = [] 148 149 self.id = len(by_id) 150 by_id.append(self) 151 152 def NetStashChange(self): 153 return (sum(sr.size() for (_, sr) in self.stash_before) - 154 sum(sr.size() for (_, sr) in self.use_stash)) 155 156 def __str__(self): 157 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 158 " to " + str(self.tgt_ranges) + ">") 159 160 161# BlockImageDiff works on two image objects. An image object is 162# anything that provides the following attributes: 163# 164# blocksize: the size in bytes of a block, currently must be 4096. 165# 166# total_blocks: the total size of the partition/image, in blocks. 167# 168# care_map: a RangeSet containing which blocks (in the range [0, 169# total_blocks) we actually care about; i.e. which blocks contain 170# data. 171# 172# file_map: a dict that partitions the blocks contained in care_map 173# into smaller domains that are useful for doing diffs on. 174# (Typically a domain is a file, and the key in file_map is the 175# pathname.) 176# 177# ReadRangeSet(): a function that takes a RangeSet and returns the 178# data contained in the image blocks of that RangeSet. The data 179# is returned as a list or tuple of strings; concatenating the 180# elements together should produce the requested data. 181# Implementations are free to break up the data into list/tuple 182# elements in any way that is convenient. 183# 184# TotalSha1(): a function that returns (as a hex string) the SHA-1 185# hash of all the data in the image (ie, all the blocks in the 186# care_map) 187# 188# When creating a BlockImageDiff, the src image may be None, in which 189# case the list of transfers produced will never read from the 190# original image. 
class BlockImageDiff(object):
  """Computes a list of block-level transfers that turn 'src' into 'tgt'.

  'tgt' and 'src' are image objects following the interface described
  in the module-level comments; 'src' may be None (treated as an empty
  image, producing a full-image update).  Compute(prefix) runs the
  whole pipeline and writes <prefix>.transfer.list, <prefix>.new.dat,
  and <prefix>.patch.dat.
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full diff pipeline and write the output files."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges):
    """Return the hex SHA-1 of the data in 'ranges' of image 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit <prefix>.transfer.list describing every transfer in order."""
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 stashes are keyed by content hash, so identical
          # data already stashed doesn't need to be stashed again.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      # "free" commands accumulated while processing this transfer.
      # Initialized here (for every version) so the "if free_string"
      # check below doesn't hit an undefined name when version == 1,
      # which never takes the stash-handling branch.
      free_string = []

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # Express the stashed range in coordinates relative to this
          # transfer's source range.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_string.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   move hash <tgt rangeset> <src_string>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source and target coincide is a no-op on the
        # device; emit nothing.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          elif self.version >= 3:
            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        elif self.version >= 3:
          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Call-style raise; the old "raise X, y" form is Python-2-only
        # syntax.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

      if free_string:
        out.append("".join(free_string))

    # sanity check: abort if we're going to need more than 512 MB of
    # stash space
    assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Resolve every "diff" transfer into move/bsdiff/imgdiff and write
    <prefix>.new.dat and <prefix>.patch.dat."""
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently,
          # eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so the largest targets (popped from the end) are diffed
      # first, keeping the worker threads busy longest.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into one file, recording each transfer's
    # offset and length within it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from the stash were stashed before being
        # overwritten, so they don't count as forbidden reads.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.  (u.order breaks ties deterministically.)
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Version 1 fixup: drop the source blocks that would be read after
    being overwritten, at the cost of larger patches."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          # The source is no longer a contiguous whole file, so
          # imgdiff can't be used on it.
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Version 2+ fixup: stash the blocks a later transfer still needs
    to read, instead of dropping them from the diff source."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between every pair of
    transfers whose source and target blocks overlap."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for each target domain, matching it to a
    source domain by path, basename, or digit-squashed basename."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the data will be carried in the package.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and by digit-squashed basename
    for the fallback matching done in FindTransfers()."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total