1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Utilities for unit testing."""
6
7from __future__ import print_function
8
9import cStringIO
10import hashlib
11import os
12import struct
13import subprocess
14
15import common
16import payload
17import update_metadata_pb2
18
19
class TestError(Exception):
  """Raised when something goes wrong while testing update payload code."""
22
23
# RSA key pair used for signing test payloads (private key) and for
# verification (public key); both files live alongside this module.
_PRIVKEY_FILE_NAME, _PUBKEY_FILE_NAME = (
    os.path.join(os.path.dirname(__file__), key_file)
    for key_file in ('payload-test-key.pem', 'payload-test-key.pub'))
29
30
def KiB(count):
  """Converts a count of kibibytes into bytes (count * 2**10)."""
  kib_shift = 10
  return count << kib_shift
33
34
def MiB(count):
  """Converts a count of mebibytes into bytes (count * 2**20)."""
  mib_shift = 20
  return count << mib_shift
37
38
def GiB(count):
  """Converts a count of gibibytes into bytes (count * 2**30)."""
  gib_shift = 30
  return count << gib_shift
41
42
def _WriteInt(file_obj, size, is_unsigned, val):
  """Writes a binary-encoded integer to a file.

  The packing format is obtained from common.IntPackingFmtStr() for the given
  size and signedness. Assumes a network (big-endian) byte ordering.

  Args:
    file_obj: a file object (or file-like object with a write() method)
    size: the integer size in bytes (2, 4 or 8)
    is_unsigned: whether the integer is unsigned (True) or signed (False)
    val: integer value to encode

  Raises:
    PayloadError if a write error occurred.
  """
  # Packing errors (struct.error) were never IOErrors, so packing outside the
  # try keeps the handled region limited to the actual write.
  packed = struct.pack(common.IntPackingFmtStr(size, is_unsigned), val)
  try:
    file_obj.write(packed)
  except IOError as e:
    # In-memory buffers (e.g. cStringIO objects) have no `name` attribute;
    # don't let an AttributeError here mask the original write error.
    raise payload.PayloadError('error writing to file (%s): %s' %
                               (getattr(file_obj, 'name', '<unnamed>'), e))
64
65
def _SetMsgField(msg, field_name, val):
  """Assigns `val` to msg.<field_name>, or clears that field if val is None."""
  if val is not None:
    setattr(msg, field_name, val)
    return
  msg.ClearField(field_name)
72
73
def SignSha256(data, privkey_file_name):
  """Signs the data's SHA256 hash with an RSA private key.

  The SHA256 digest of `data` is prefixed with common.SIG_ASN1_HEADER and the
  resulting byte string is signed by running `openssl rsautl -sign`.

  Args:
    data: the data whose SHA256 hash we want to sign
    privkey_file_name: private key used for signing data

  Returns:
    The signature produced by openssl over the ASN1-header-prefixed SHA256
    digest of `data`.

  Raises:
    TestError if the signing subprocess cannot be run or exits with a
    non-zero status.
  """
  # pylint: disable=E1101
  asn1_prefixed_digest = (common.SIG_ASN1_HEADER +
                          hashlib.sha256(data).digest())
  sign_cmd = ['openssl', 'rsautl', '-sign', '-inkey', privkey_file_name]
  try:
    sign_process = subprocess.Popen(sign_cmd, stdin=subprocess.PIPE,
                                    stdout=subprocess.PIPE)
    sig, _ = sign_process.communicate(input=asn1_prefixed_digest)
  except Exception as e:
    raise TestError('signing subprocess failed: %s' % e)

  # If openssl fails (e.g. unreadable key), its stdout may be empty or partial;
  # never hand that back to the caller as if it were a valid signature.
  if sign_process.returncode != 0:
    raise TestError('signing subprocess failed with exit code %d' %
                    sign_process.returncode)

  return sig
98
99
class SignaturesGenerator(object):
  """Builds a serialized Signatures message for a payload's signature block.

  Attributes:
    sigs: the update_metadata_pb2.Signatures message being populated
  """

  def __init__(self):
    self.sigs = update_metadata_pb2.Signatures()

  def AddSig(self, version, data):
    """Appends a new entry to the message's `signatures` sequence.

    Args:
      version: value for the entry's `version` field; None leaves it unset
      data: value for the entry's `data` field; None leaves it unset
    """
    # Pylint fails to identify a member of the Signatures message.
    # pylint: disable=E1101
    new_sig = self.sigs.signatures.add()
    for field_name, field_val in (('version', version), ('data', data)):
      if field_val is not None:
        setattr(new_sig, field_name, field_val)

  def ToBinary(self):
    """Serializes the accumulated Signatures message to a byte string."""
    return self.sigs.SerializeToString()
124
125
class PayloadGenerator(object):
  """Generates an update payload allowing low-level control.

  Attributes:
    manifest: the protobuf containing the payload manifest
    version: the payload version identifier
    block_size: the block size pertaining to update operations
  """

  def __init__(self, version=1):
    self.manifest = update_metadata_pb2.DeltaArchiveManifest()
    self.version = version
    self.block_size = 0

  @staticmethod
  def _WriteExtent(ex, val):
    """Populates an Extent message in place (returns nothing).

    Args:
      ex: the Extent message to fill in
      val: a (start_block, num_blocks) pair; a None element clears the
           corresponding field
    """
    start_block, num_blocks = val
    _SetMsgField(ex, 'start_block', start_block)
    _SetMsgField(ex, 'num_blocks', num_blocks)

  @staticmethod
  def _AddValuesToRepeatedField(repeated_field, values, write_func):
    """Appends one new item per value to a repeated message field.

    Args:
      repeated_field: the repeated field to extend (must support add())
      values: iterable of values to add; None or empty adds nothing
      write_func: callable(new_item, value) that populates each new item
    """
    for val in values or ():
      write_func(repeated_field.add(), val)

  @staticmethod
  def _AddExtents(extents_field, values):
    """Adds (start_block, num_blocks) pairs to a repeated Extent field."""
    PayloadGenerator._AddValuesToRepeatedField(
        extents_field, values, PayloadGenerator._WriteExtent)

  def SetBlockSize(self, block_size):
    """Sets the payload's block size, both locally and in the manifest.

    Args:
      block_size: the block size; None clears the manifest field
    """
    self.block_size = block_size
    _SetMsgField(self.manifest, 'block_size', block_size)

  def SetPartInfo(self, is_kernel, is_new, part_size, part_hash):
    """Set the partition info entry.

    Args:
      is_kernel: whether this is kernel (True) or rootfs (False) partition info
      is_new: whether to set old (False) or new (True) info
      part_size: the partition size (in fact, filesystem size); None clears it
      part_hash: the partition hash; None clears it
    """
    if is_kernel:
      # pylint: disable=E1101
      part_info = (self.manifest.new_kernel_info if is_new
                   else self.manifest.old_kernel_info)
    else:
      # pylint: disable=E1101
      part_info = (self.manifest.new_rootfs_info if is_new
                   else self.manifest.old_rootfs_info)
    _SetMsgField(part_info, 'size', part_size)
    _SetMsgField(part_info, 'hash', part_hash)

  def AddOperation(self, is_kernel, op_type, data_offset=None,
                   data_length=None, src_extents=None, src_length=None,
                   dst_extents=None, dst_length=None, data_sha256_hash=None):
    """Adds an InstallOperation entry to the manifest.

    Args:
      is_kernel: add to kernel (True) or rootfs (False) operations
      op_type: the operation type
      data_offset: offset of the operation's data blob (None: leave unset)
      data_length: length of the operation's data blob (None: leave unset)
      src_extents: list of (start_block, num_blocks) source pairs
      src_length: size of the source data in bytes (None: leave unset)
      dst_extents: list of (start_block, num_blocks) destination pairs
      dst_length: size of the destination data in bytes (None: leave unset)
      data_sha256_hash: SHA256 digest of the data blob (None: leave unset)
    """
    # pylint: disable=E1101
    operations = (self.manifest.kernel_install_operations if is_kernel
                  else self.manifest.install_operations)

    op = operations.add()
    op.type = op_type

    _SetMsgField(op, 'data_offset', data_offset)
    _SetMsgField(op, 'data_length', data_length)

    self._AddExtents(op.src_extents, src_extents)
    _SetMsgField(op, 'src_length', src_length)

    self._AddExtents(op.dst_extents, dst_extents)
    _SetMsgField(op, 'dst_length', dst_length)

    _SetMsgField(op, 'data_sha256_hash', data_sha256_hash)

  def SetSignatures(self, sigs_offset, sigs_size):
    """Set the payload's signature block descriptors in the manifest."""
    _SetMsgField(self.manifest, 'signatures_offset', sigs_offset)
    _SetMsgField(self.manifest, 'signatures_size', sigs_size)

  def SetMinorVersion(self, minor_version):
    """Set the payload's minor version field in the manifest."""
    _SetMsgField(self.manifest, 'minor_version', minor_version)

  def _WriteHeaderToFile(self, file_obj, manifest_len):
    """Writes a payload header to a file.

    The header consists of the magic string, the payload version and the
    manifest length, with field sizes taken from payload.Payload._PayloadHeader.
    """
    # We need to access protected members in Payload for writing the header.
    # pylint: disable=W0212
    file_obj.write(payload.Payload._PayloadHeader._MAGIC)
    _WriteInt(file_obj, payload.Payload._PayloadHeader._VERSION_SIZE, True,
              self.version)
    _WriteInt(file_obj, payload.Payload._PayloadHeader._MANIFEST_LEN_SIZE, True,
              manifest_len)

  def WriteToFile(self, file_obj, manifest_len=-1, data_blobs=None,
                  sigs_data=None, padding=None):
    """Writes the payload content to a file.

    Content is written in order: header, serialized manifest, data blobs,
    signatures blob, padding.

    Args:
      file_obj: a file object open for writing
      manifest_len: manifest length to record in the header; a negative value
                    means use the actual serialized manifest length
      data_blobs: a list of data blobs to be concatenated to the payload
      sigs_data: a binary Signatures message to be concatenated to the payload
      padding: stuff to dump past the normal data blobs provided (optional)
    """
    manifest = self.manifest.SerializeToString()
    if manifest_len < 0:
      manifest_len = len(manifest)
    self._WriteHeaderToFile(file_obj, manifest_len)
    file_obj.write(manifest)
    for data_blob in data_blobs or ():
      file_obj.write(data_blob)
    if sigs_data:
      file_obj.write(sigs_data)
    if padding:
      file_obj.write(padding)
251
252
class EnhancedPayloadGenerator(PayloadGenerator):
  """Payload generator with automatic handling of data blobs.

  Attributes:
    data_blobs: a list of blobs, in the order they were added
    curr_offset: the currently consumed offset of blobs added to the payload
  """

  def __init__(self, version=1):
    super(EnhancedPayloadGenerator, self).__init__(version=version)
    self.data_blobs = []
    self.curr_offset = 0

  def AddData(self, data_blob):
    """Adds a (possibly orphan) data blob.

    Args:
      data_blob: the blob to append to the payload's data section

    Returns:
      A (data_length, data_offset) pair, where data_offset is relative to the
      start of the first data blob.
    """
    data_length = len(data_blob)
    data_offset = self.curr_offset
    self.curr_offset += data_length
    self.data_blobs.append(data_blob)
    return data_length, data_offset

  def AddOperationWithData(self, is_kernel, op_type, src_extents=None,
                           src_length=None, dst_extents=None, dst_length=None,
                           data_blob=None, do_hash_data_blob=True):
    """Adds an install operation and associated data blob.

    This takes care of obtaining a hash of the data blob (if so instructed)
    and appending it to the internally maintained list of blobs, including the
    necessary offset/length accounting.

    Args:
      is_kernel: whether this is a kernel (True) or rootfs (False) operation
      op_type: one of REPLACE, REPLACE_BZ, MOVE or BSDIFF
      src_extents: list of (start, length) pairs indicating src block ranges
      src_length: size of the src data in bytes (needed for BSDIFF)
      dst_extents: list of (start, length) pairs indicating dst block ranges
      dst_length: size of the dst data in bytes (needed for BSDIFF)
      data_blob: a data blob associated with this operation
      do_hash_data_blob: whether or not to compute and add a data blob hash
    """
    data_offset = data_length = data_sha256_hash = None
    if data_blob is not None:
      if do_hash_data_blob:
        # pylint: disable=E1101
        data_sha256_hash = hashlib.sha256(data_blob).digest()
      data_length, data_offset = self.AddData(data_blob)

    self.AddOperation(is_kernel, op_type, data_offset=data_offset,
                      data_length=data_length, src_extents=src_extents,
                      src_length=src_length, dst_extents=dst_extents,
                      dst_length=dst_length, data_sha256_hash=data_sha256_hash)

  def WriteToFileWithData(self, file_obj, sigs_data=None,
                          privkey_file_name=None,
                          do_add_pseudo_operation=False,
                          is_pseudo_in_kernel=False, padding=None):
    """Writes the payload content to a file, optionally signing the content.

    Args:
      file_obj: a file object open for writing
      sigs_data: signatures blob to be appended to the payload (optional;
                 payload signature fields assumed to be preset by the caller)
      privkey_file_name: key used for signing the payload (optional; used only
                         if explicit signatures blob not provided)
      do_add_pseudo_operation: whether a pseudo-operation should be added to
                               account for the signature blob
      is_pseudo_in_kernel: whether the pseudo-operation should be added to
                           kernel (True) or rootfs (False) operations
      padding: stuff to dump past the normal data blobs provided (optional)

    Raises:
      TestError: if arguments are inconsistent or something goes wrong.
    """
    sigs_len = len(sigs_data) if sigs_data else 0

    # Do we need to generate a genuine signatures blob?
    do_generate_sigs_data = sigs_data is None and bool(privkey_file_name)

    if do_generate_sigs_data:
      # The signature covers the manifest, which itself must record the
      # signature blob's size; so first sign arbitrary data to learn that size.
      fake_sig = SignSha256('fake-payload-data', privkey_file_name)
      fake_sigs_gen = SignaturesGenerator()
      fake_sigs_gen.AddSig(1, fake_sig)
      sigs_len = len(fake_sigs_gen.ToBinary())

      # The signature blob is placed right after all data blobs.
      self.SetSignatures(self.curr_offset, sigs_len)

    # Add a pseudo-operation to account for the signature blob, if requested.
    if do_add_pseudo_operation:
      if not self.block_size:
        raise TestError('cannot add pseudo-operation without knowing the '
                        'payload block size')
      # Number of blocks spanned by the signature blob, rounded up; floor
      # division keeps this an integer regardless of division semantics.
      num_sig_blocks = (sigs_len + self.block_size - 1) // self.block_size
      self.AddOperation(
          is_pseudo_in_kernel, common.OpType.REPLACE,
          data_offset=self.curr_offset, data_length=sigs_len,
          dst_extents=[(common.PSEUDO_EXTENT_MARKER, num_sig_blocks)])

    if do_generate_sigs_data:
      # Once all payload fields are updated, dump and sign it.
      temp_payload_file = cStringIO.StringIO()
      self.WriteToFile(temp_payload_file, data_blobs=self.data_blobs)
      sig = SignSha256(temp_payload_file.getvalue(), privkey_file_name)
      sigs_gen = SignaturesGenerator()
      sigs_gen.AddSig(1, sig)
      sigs_data = sigs_gen.ToBinary()
      # An explicit check (rather than `assert`) so it survives `python -O`;
      # a mismatch would leave signatures_size in the manifest wrong.
      if len(sigs_data) != sigs_len:
        raise TestError('signature blob lengths mismatch (expected %d, got %d)'
                        % (sigs_len, len(sigs_data)))

    # Dump the whole thing, complete with data and signature blob, to a file.
    self.WriteToFile(file_obj, data_blobs=self.data_blobs, sigs_data=sigs_data,
                     padding=padding)
365