blockimgdiff.py revision 7b985f6aedad882c7b3332032bd9265ec42fc871
# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import threading
import tempfile

# RangeSet is the only name this module uses from rangelib; import it
# explicitly instead of via a wildcard so the namespace stays predictable.
from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms 'src' into 'tgt'.

  Args:
    src: an iterable of strings; concatenated, the source data.
    tgt: an iterable of strings; concatenated, the target data.
    imgdiff: if True, invoke the external "imgdiff -z" tool (better for
        zip-format data); otherwise invoke "bsdiff".

  Returns:
    The raw patch file contents as a string.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  # The patch file is written by the child process; we only need its name.
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      # os.devnull instead of a hard-coded "/dev/null" keeps this working
      # on platforms where /dev/null does not exist.
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open(os.devnull, "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup of the temporary files.
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' as an image.

    Args:
      data: the raw image contents as a single string.
      trim: if True, drop a trailing partial block.
      pad: if True, zero-fill a trailing partial block.

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
          neither trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    # 'trim' and 'pad' are mutually exclusive ways of handling a
    # partial final block.
    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int under Python 3 as well.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify every block as all-zero or not.  Each block contributes a
    # (start, end) pair; RangeSet is expected to coalesce adjacent pairs.
    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    # Computed lazily and cached; the data never changes after __init__.
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1


class Transfer(object):
  """One transfer command: how one target RangeSet receives its data."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # 'intact' means both range sets are monotonic, so concatenating the
    # blocks in order reproduces the original file contents (a
    # precondition for using imgdiff).
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    # Lists of (id, RangeSet) pairs: blocks stashed just before this
    # transfer runs, and stashed blocks this transfer consumes.
    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Return the net change (in blocks) to stash usage from this transfer."""
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map)
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Computes a list of block-level transfer commands turning src into tgt.

  'tgt' and 'src' are image objects (see the module-level comment above
  for the interface they must provide).  'version' selects the
  transfer-list format (1, 2, or 3) understood by the on-device updater.
  Call Compute(prefix) to write the '<prefix>.*' output files.
  """

  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full pipeline and write the '<prefix>.*' output files."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges):
    """Return the hex SHA-1 of the data that 'ranges' selects in 'source'."""
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    """Emit '<prefix>.transfer.list' describing all transfer commands."""
    out = []

    total = 0
    performs_read = False

    # NOTE(review): for version 3 this dict is keyed both by the integer
    # stash id (stashes[s] = sid below) and by the SHA-1 of the stashed
    # data (reference counts) — confirm the two key spaces cannot collide.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 identifies stashes by content hash and
          # reference-counts duplicate contents.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_string.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_string>
      #   move hash <tgt rangeset> <src_string>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          elif self.version >= 3:
            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        elif self.version >= 3:
          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Python-3-compatible raise form (the old "raise E, msg"
        # statement is Python-2-only syntax).
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    """Resolve "diff" transfers and write the data/patch output files.

    Writes '<prefix>.new.dat' (raw data for "new" transfers) and
    '<prefix>.patch.dat' (concatenated patches), recording each
    transfer's patch offset and length on the Transfer object.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        # Workers pull items off the shared queue under the lock; the
        # expensive compute_patch call runs outside the lock.
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert the transfer sequence never reads a block after writing it."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        # Blocks read from stash were saved before being overwritten,
        # so exclude them from the not-yet-touched requirement.
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    """Topologically re-sort transfers to minimize peak stash usage."""
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    """Trim source blocks so the chosen order has no backward reads (v1)."""
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    """Resolve backward reads by stashing instead of trimming (v2+)."""
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    """Heuristically order transfers to minimize backward-edge weight."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create a Transfer for every target domain, matching source files."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and digit-collapsed basename."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total