1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Applying a Chrome OS update payload.
6
7This module is used internally by the main Payload class for applying an update
8payload. The interface for invoking the applier is as follows:
9
10  applier = PayloadApplier(payload)
11  applier.Run(...)
12
13"""
14
15from __future__ import print_function
16
17import array
18import bz2
19import hashlib
20import itertools
21import os
22import shutil
23import subprocess
24import sys
25import tempfile
26
27import common
28from error import PayloadError
29
30
31#
32# Helper functions.
33#
34def _VerifySha256(file_obj, expected_hash, name, length=-1):
35  """Verifies the SHA256 hash of a file.
36
37  Args:
38    file_obj: file object to read
39    expected_hash: the hash digest we expect to be getting
40    name: name string of this hash, for error reporting
41    length: precise length of data to verify (optional)
42
43  Raises:
44    PayloadError if computed hash doesn't match expected one, or if fails to
45    read the specified length of data.
46  """
47  # pylint: disable=E1101
48  hasher = hashlib.sha256()
49  block_length = 1024 * 1024
50  max_length = length if length >= 0 else sys.maxint
51
52  while max_length > 0:
53    read_length = min(max_length, block_length)
54    data = file_obj.read(read_length)
55    if not data:
56      break
57    max_length -= len(data)
58    hasher.update(data)
59
60  if length >= 0 and max_length > 0:
61    raise PayloadError(
62        'insufficient data (%d instead of %d) when verifying %s' %
63        (length - max_length, length, name))
64
65  actual_hash = hasher.digest()
66  if actual_hash != expected_hash:
67    raise PayloadError('%s hash (%s) not as expected (%s)' %
68                       (name, common.FormatSha256(actual_hash),
69                        common.FormatSha256(expected_hash)))
70
71
def _ReadExtents(file_obj, extents, block_size, max_length=-1):
  """Collects file data at the positions named by an extent sequence.

  Efficiency note: chunks are read straight into a character array, avoiding
  intermediate copies.

  Args:
    file_obj: file object
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    max_length: maximum length to read (optional)

  Returns:
    A character array containing the concatenated read data.
  """
  remaining = sys.maxint if max_length < 0 else max_length
  data = array.array('c')
  for ex in extents:
    if remaining == 0:
      break
    chunk_size = min(remaining, ex.num_blocks * block_size)

    # A pseudo-extent contributes zeros; a real extent is read from the file.
    if ex.start_block == common.PSEUDO_EXTENT_MARKER:
      data.extend(itertools.repeat('\0', chunk_size))
    else:
      file_obj.seek(ex.start_block * block_size)
      data.fromfile(file_obj, chunk_size)

    remaining -= chunk_size

  return data
104
105
def _WriteExtents(file_obj, data, extents, block_size, base_name):
  """Writes a data blob to a file at the positions named by an extent sequence.

  Efficiency note: each chunk is written through a buffer view of the data,
  so no copies are made while writing.

  Args:
    file_obj: file object
    data: data to write
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    base_name: name string of extent sequence for error reporting

  Raises:
    PayloadError when the data length and the extents don't add up.
  """
  offset = 0
  remaining = len(data)
  for ex, ex_name in common.ExtentIter(extents, base_name):
    if not remaining:
      raise PayloadError('%s: more write extents than data' % ex_name)
    chunk_size = min(remaining, ex.num_blocks * block_size)

    # Nothing is written for a pseudo-extent; its bytes are simply consumed.
    if ex.start_block != common.PSEUDO_EXTENT_MARKER:
      file_obj.seek(ex.start_block * block_size)
      file_obj.write(buffer(data, offset, chunk_size))

    offset += chunk_size
    remaining -= chunk_size

  if remaining:
    raise PayloadError('%s: more data than write extents' % base_name)
139
140
def _ExtentsToBspatchArg(extents, block_size, base_name, data_length=-1):
  """Builds the extents string argument understood by bspatch.

  Args:
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    base_name: name string of extent sequence for error reporting
    data_length: the actual total length of the data in bytes (optional)

  Returns:
    A tuple consisting of (i) a string of the form
    "off_1:len_1,...,off_n:len_n", (ii) an offset where zero padding is needed
    for filling the last extent, (iii) the length of the padding (zero means no
    padding is needed and the extents cover the full length of data).

  Raises:
    PayloadError if data_length is too short or too long.
  """
  arg = ''
  pad_off = pad_len = 0
  remaining = sys.maxint if data_length < 0 else data_length
  for ex, ex_name in common.ExtentIter(extents, base_name):
    if not remaining:
      raise PayloadError('%s: more extents than total data length' % ex_name)

    is_pseudo = ex.start_block == common.PSEUDO_EXTENT_MARKER
    # A pseudo-extent has no real on-disk position; bspatch takes -1 for it.
    start_byte = -1 if is_pseudo else ex.start_block * block_size
    num_bytes = ex.num_blocks * block_size
    if remaining < num_bytes:
      # Only a real extent needs zero padding; a pseudo one has no backing.
      if not is_pseudo:
        pad_off = start_byte + remaining
        pad_len = num_bytes - remaining

      num_bytes = remaining

    arg += '%s%d:%d' % (arg and ',', start_byte, num_bytes)
    remaining -= num_bytes

  if remaining:
    raise PayloadError('%s: extents not covering full data length' % base_name)

  return arg, pad_off, pad_len
185
186
187#
188# Payload application.
189#
class PayloadApplier(object):
  """Applying an update payload.

  This is a short-lived object whose purpose is to isolate the logic used for
  applying an update payload.
  """

  def __init__(self, payload, bsdiff_in_place=True, bspatch_path=None,
               imgpatch_path=None, truncate_to_expected_size=True):
    """Initialize the applier.

    Args:
      payload: the payload object to check
      bsdiff_in_place: whether to perform BSDIFF operation in-place (optional)
      bspatch_path: path to the bspatch binary (optional)
      imgpatch_path: path to the imgpatch binary (optional)
      truncate_to_expected_size: whether to truncate the resulting partitions
                                 to their expected sizes, as specified in the
                                 payload (optional)
    """
    assert payload.is_init, 'uninitialized update payload'
    self.payload = payload
    self.block_size = payload.manifest.block_size
    self.minor_version = payload.manifest.minor_version
    self.bsdiff_in_place = bsdiff_in_place
    self.bspatch_path = bspatch_path or 'bspatch'
    self.imgpatch_path = imgpatch_path or 'imgpatch'
    self.truncate_to_expected_size = truncate_to_expected_size

  def _ApplyReplaceOperation(self, op, op_name, out_data, part_file, part_size):
    """Applies a REPLACE{,_BZ} operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      out_data: the data to be written
      part_file: the partition file object
      part_size: the size of the partition

    Raises:
      PayloadError if something goes wrong.
    """
    block_size = self.block_size
    data_length = len(out_data)

    # Decompress data if needed.
    if op.type == common.OpType.REPLACE_BZ:
      out_data = bz2.decompress(out_data)
      data_length = len(out_data)

    # Write data to blocks specified in dst extents.
    data_start = 0
    for ex, ex_name in common.ExtentIter(op.dst_extents,
                                         '%s.dst_extents' % op_name):
      start_block = ex.start_block
      num_blocks = ex.num_blocks
      count = num_blocks * block_size

      # Make sure it's not a fake (signature) operation.
      if start_block != common.PSEUDO_EXTENT_MARKER:
        data_end = data_start + count

        # Make sure we're not running past partition boundary.
        if (start_block + num_blocks) * block_size > part_size:
          raise PayloadError(
              '%s: extent (%s) exceeds partition size (%d)' %
              (ex_name, common.FormatExtent(ex, block_size),
               part_size))

        # Make sure that we have enough data to write.
        if data_end >= data_length + block_size:
          # FIX: the format argument was missing, so the error message was
          # emitted with a literal, unfilled '%s'.
          raise PayloadError(
              '%s: more dst blocks than data (even with padding)' % op_name)

        # Pad with zeros if necessary.
        if data_end > data_length:
          padding = data_end - data_length
          out_data += '\0' * padding

        # FIX: removed a stray seek on self.payload.payload_file here; the
        # write target is part_file, and repositioning the payload file by
        # partition block offsets corrupted subsequent blob reads.
        part_file.seek(start_block * block_size)
        part_file.write(out_data[data_start:data_end])

      data_start += count

    # Make sure we wrote all data.
    if data_start < data_length:
      raise PayloadError('%s: wrote fewer bytes (%d) than expected (%d)' %
                         (op_name, data_start, data_length))

  def _ApplyMoveOperation(self, op, op_name, part_file):
    """Applies a MOVE operation.

    Note that this operation must read the whole block data from the input and
    only then dump it, due to our in-place update semantics; otherwise, it
    might clobber data midway through.

    Args:
      op: the operation object
      op_name: name string for error reporting
      part_file: the partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    block_size = self.block_size

    # Gather input raw data from src extents.
    in_data = _ReadExtents(part_file, op.src_extents, block_size)

    # Dump extracted data to dst extents.
    _WriteExtents(part_file, in_data, op.dst_extents, block_size,
                  '%s.dst_extents' % op_name)

  def _ApplyBsdiffOperation(self, op, op_name, patch_data, new_part_file):
    """Applies a BSDIFF operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      patch_data: the binary patch content
      new_part_file: the target partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    # Implemented using a SOURCE_BSDIFF operation with the source and target
    # partition set to the new partition.
    self._ApplyDiffOperation(op, op_name, patch_data, new_part_file,
                             new_part_file)

  def _ApplySourceCopyOperation(self, op, op_name, old_part_file,
                                new_part_file):
    """Applies a SOURCE_COPY operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      old_part_file: the old partition file object
      new_part_file: the new partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    if not old_part_file:
      raise PayloadError(
          '%s: no source partition file provided for operation type (%d)' %
          (op_name, op.type))

    block_size = self.block_size

    # Gather input raw data from src extents.
    in_data = _ReadExtents(old_part_file, op.src_extents, block_size)

    # Dump extracted data to dst extents.
    _WriteExtents(new_part_file, in_data, op.dst_extents, block_size,
                  '%s.dst_extents' % op_name)

  def _ApplyDiffOperation(self, op, op_name, patch_data, old_part_file,
                          new_part_file):
    """Applies a SOURCE_BSDIFF or IMGDIFF operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      patch_data: the binary patch content
      old_part_file: the source partition file object
      new_part_file: the target partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    if not old_part_file:
      raise PayloadError(
          '%s: no source partition file provided for operation type (%d)' %
          (op_name, op.type))

    block_size = self.block_size

    # Dump patch data to file.
    with tempfile.NamedTemporaryFile(delete=False) as patch_file:
      patch_file_name = patch_file.name
      patch_file.write(patch_data)

    # If both partition files are real files (have file descriptors), bspatch
    # can operate on them directly via /dev/fd with extent arguments; imgpatch
    # has no extents interface, so IMGDIFF always takes the temp-file path.
    if (hasattr(new_part_file, 'fileno') and
        ((not old_part_file) or hasattr(old_part_file, 'fileno')) and
        op.type != common.OpType.IMGDIFF):
      # Construct input and output extents argument for bspatch.
      in_extents_arg, _, _ = _ExtentsToBspatchArg(
          op.src_extents, block_size, '%s.src_extents' % op_name,
          data_length=op.src_length)
      out_extents_arg, pad_off, pad_len = _ExtentsToBspatchArg(
          op.dst_extents, block_size, '%s.dst_extents' % op_name,
          data_length=op.dst_length)

      new_file_name = '/dev/fd/%d' % new_part_file.fileno()
      # Diff from source partition.
      old_file_name = '/dev/fd/%d' % old_part_file.fileno()

      # Invoke bspatch on partition file with extents args.
      bspatch_cmd = [self.bspatch_path, old_file_name, new_file_name,
                     patch_file_name, in_extents_arg, out_extents_arg]
      subprocess.check_call(bspatch_cmd)

      # Pad with zeros past the total output length.
      if pad_len:
        new_part_file.seek(pad_off)
        new_part_file.write('\0' * pad_len)
    else:
      # Gather input raw data and write to a temp file.
      input_part_file = old_part_file if old_part_file else new_part_file
      in_data = _ReadExtents(input_part_file, op.src_extents, block_size,
                             max_length=op.src_length)
      with tempfile.NamedTemporaryFile(delete=False) as in_file:
        in_file_name = in_file.name
        in_file.write(in_data)

      # Allocate temporary output file.
      with tempfile.NamedTemporaryFile(delete=False) as out_file:
        out_file_name = out_file.name

      # Invoke bspatch (or imgpatch for an IMGDIFF operation).
      patch_cmd = [self.bspatch_path, in_file_name, out_file_name,
                   patch_file_name]
      if op.type == common.OpType.IMGDIFF:
        patch_cmd[0] = self.imgpatch_path
      subprocess.check_call(patch_cmd)

      # Read output.
      with open(out_file_name, 'rb') as out_file:
        out_data = out_file.read()
        if len(out_data) != op.dst_length:
          raise PayloadError(
              '%s: actual patched data length (%d) not as expected (%d)' %
              (op_name, len(out_data), op.dst_length))

      # Write output back to partition, with padding to a block boundary.
      unaligned_out_len = len(out_data) % block_size
      if unaligned_out_len:
        out_data += '\0' * (block_size - unaligned_out_len)
      _WriteExtents(new_part_file, out_data, op.dst_extents, block_size,
                    '%s.dst_extents' % op_name)

      # Delete input/output files.
      os.remove(in_file_name)
      os.remove(out_file_name)

    # Delete patch file.
    os.remove(patch_file_name)

  def _ApplyOperations(self, operations, base_name, old_part_file,
                       new_part_file, part_size):
    """Applies a sequence of update operations to a partition.

    This assumes an in-place update semantics for MOVE and BSDIFF, namely all
    reads are performed first, then the data is processed and written back to
    the same file.

    Args:
      operations: the sequence of operations
      base_name: the name of the operation sequence
      old_part_file: the old partition file object, open for reading/writing
      new_part_file: the new partition file object, open for reading/writing
      part_size: the partition size

    Raises:
      PayloadError if anything goes wrong while processing the payload.
    """
    for op, op_name in common.OperationIter(operations, base_name):
      # Read data blob.
      data = self.payload.ReadDataBlob(op.data_offset, op.data_length)

      # Dispatch on operation type.
      if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
        self._ApplyReplaceOperation(op, op_name, data, new_part_file, part_size)
      elif op.type == common.OpType.MOVE:
        self._ApplyMoveOperation(op, op_name, new_part_file)
      elif op.type == common.OpType.BSDIFF:
        self._ApplyBsdiffOperation(op, op_name, data, new_part_file)
      elif op.type == common.OpType.SOURCE_COPY:
        self._ApplySourceCopyOperation(op, op_name, old_part_file,
                                       new_part_file)
      elif op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.IMGDIFF):
        self._ApplyDiffOperation(op, op_name, data, old_part_file,
                                 new_part_file)
      else:
        raise PayloadError('%s: unknown operation type (%d)' %
                           (op_name, op.type))

  def _ApplyToPartition(self, operations, part_name, base_name,
                        new_part_file_name, new_part_info,
                        old_part_file_name=None, old_part_info=None):
    """Applies an update to a partition.

    Args:
      operations: the sequence of update operations to apply
      part_name: the name of the partition, for error reporting
      base_name: the name of the operation sequence
      new_part_file_name: file name to write partition data to
      new_part_info: size and expected hash of dest partition
      old_part_file_name: file name of source partition (optional)
      old_part_info: size and expected hash of source partition (optional)

    Raises:
      PayloadError if anything goes wrong with the update.
    """
    # Do we have a source partition?
    if old_part_file_name:
      # Verify the source partition.
      with open(old_part_file_name, 'rb') as old_part_file:
        _VerifySha256(old_part_file, old_part_info.hash,
                      'old ' + part_name, length=old_part_info.size)
      new_part_file_mode = 'r+b'
      if self.minor_version == common.INPLACE_MINOR_PAYLOAD_VERSION:
        # Copy the src partition to the dst one; make sure we don't truncate it.
        shutil.copyfile(old_part_file_name, new_part_file_name)
      elif (self.minor_version == common.SOURCE_MINOR_PAYLOAD_VERSION or
            self.minor_version == common.OPSRCHASH_MINOR_PAYLOAD_VERSION or
            self.minor_version == common.IMGDIFF_MINOR_PAYLOAD_VERSION):
        # In minor version >= 2, we don't want to copy the partitions, so
        # instead just make the new partition file.
        open(new_part_file_name, 'w').close()
      else:
        raise PayloadError("Unknown minor version: %d" % self.minor_version)
    else:
      # We need to create/truncate the dst partition file.
      new_part_file_mode = 'w+b'

    # Apply operations.
    with open(new_part_file_name, new_part_file_mode) as new_part_file:
      old_part_file = (open(old_part_file_name, 'r+b')
                       if old_part_file_name else None)
      try:
        self._ApplyOperations(operations, base_name, old_part_file,
                              new_part_file, new_part_info.size)
      finally:
        if old_part_file:
          old_part_file.close()

      # Truncate the result, if so instructed.
      if self.truncate_to_expected_size:
        new_part_file.seek(0, 2)
        if new_part_file.tell() > new_part_info.size:
          new_part_file.seek(new_part_info.size)
          new_part_file.truncate()

    # Verify the resulting partition.
    with open(new_part_file_name, 'rb') as new_part_file:
      _VerifySha256(new_part_file, new_part_info.hash,
                    'new ' + part_name, length=new_part_info.size)

  def Run(self, new_kernel_part, new_rootfs_part, old_kernel_part=None,
          old_rootfs_part=None):
    """Applier entry point, invoking all update operations.

    Args:
      new_kernel_part: name of dest kernel partition file
      new_rootfs_part: name of dest rootfs partition file
      old_kernel_part: name of source kernel partition file (optional)
      old_rootfs_part: name of source rootfs partition file (optional)

    Raises:
      PayloadError if payload application failed.
    """
    self.payload.ResetFile()

    # Make sure the arguments are sane and match the payload: a full update
    # takes no source partitions; a delta update takes both.
    if not (new_kernel_part and new_rootfs_part):
      raise PayloadError('missing dst {kernel,rootfs} partitions')

    if not (old_kernel_part or old_rootfs_part):
      if not self.payload.IsFull():
        raise PayloadError('trying to apply a non-full update without src '
                           '{kernel,rootfs} partitions')
    elif old_kernel_part and old_rootfs_part:
      if not self.payload.IsDelta():
        raise PayloadError('trying to apply a non-delta update onto src '
                           '{kernel,rootfs} partitions')
    else:
      raise PayloadError('not all src partitions provided')

    # Apply update to rootfs.
    self._ApplyToPartition(
        self.payload.manifest.install_operations, 'rootfs',
        'install_operations', new_rootfs_part,
        self.payload.manifest.new_rootfs_info, old_rootfs_part,
        self.payload.manifest.old_rootfs_info)

    # Apply update to kernel.
    self._ApplyToPartition(
        self.payload.manifest.kernel_install_operations, 'kernel',
        'kernel_install_operations', new_kernel_part,
        self.payload.manifest.new_kernel_info, old_kernel_part,
        self.payload.manifest.old_kernel_info)
583