1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# http://code.google.com/p/protobuf/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38__author__ = 'robinson@google.com (Will Robinson)'
39
40import operator
41import struct
42
43import unittest
44# TODO(robinson): When we split this test in two, only some of these imports
45# will be necessary in each test.
46from google.protobuf import unittest_import_pb2
47from google.protobuf import unittest_mset_pb2
48from google.protobuf import unittest_pb2
49from google.protobuf import descriptor_pb2
50from google.protobuf import descriptor
51from google.protobuf import message
52from google.protobuf import reflection
53from google.protobuf.internal import more_extensions_pb2
54from google.protobuf.internal import more_messages_pb2
55from google.protobuf.internal import wire_format
56from google.protobuf.internal import test_util
57from google.protobuf.internal import decoder
58
59
class _MiniDecoder(object):
  """Tiny sequential reader over a serialized protocol buffer string.

  The original decoder.Decoder class was removed when decoding was
  redesigned for speed, but a few tests in this file still inspect
  serialized output through its read methods.  This class re-creates just
  those methods so that the tests did not have to be rewritten.
  """

  def __init__(self, bytes):
    # The parameter name shadows the builtin `bytes`; it is kept because it
    # is part of the constructor's signature.
    self._bytes = bytes
    self._pos = 0

  def _ReadFixed(self, fmt, width):
    """Unpacks one fixed-width value at the cursor, then advances it.

    The cursor only moves after a successful unpack, so a truncated buffer
    raises struct.error and leaves the position unchanged.
    """
    start = self._pos
    value = struct.unpack(fmt, self._bytes[start:start + width])[0]
    self._pos = start + width
    return value

  def ReadVarint(self):
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # All plain varint integer types are read identically.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    raw = self.ReadVarint()
    return wire_format.ZigZagDecode(raw)

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    tag = self.ReadVarint()
    return wire_format.UnpackTag(tag)

  def ReadFloat(self):
    return self._ReadFixed("<f", 4)

  def ReadDouble(self):
    return self._ReadFixed("<d", 8)

  def EndOfStream(self):
    return len(self._bytes) == self._pos
103
104
105class ReflectionTest(unittest.TestCase):
106
107  def assertIs(self, values, others):
108    self.assertEqual(len(values), len(others))
109    for i in range(len(values)):
110      self.assertTrue(values[i] is others[i])
111
112  def testScalarConstructor(self):
113    # Constructor with only scalar types should succeed.
114    proto = unittest_pb2.TestAllTypes(
115        optional_int32=24,
116        optional_double=54.321,
117        optional_string='optional_string')
118
119    self.assertEqual(24, proto.optional_int32)
120    self.assertEqual(54.321, proto.optional_double)
121    self.assertEqual('optional_string', proto.optional_string)
122
123  def testRepeatedScalarConstructor(self):
124    # Constructor with only repeated scalar types should succeed.
125    proto = unittest_pb2.TestAllTypes(
126        repeated_int32=[1, 2, 3, 4],
127        repeated_double=[1.23, 54.321],
128        repeated_bool=[True, False, False],
129        repeated_string=["optional_string"])
130
131    self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
132    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
133    self.assertEquals([True, False, False], list(proto.repeated_bool))
134    self.assertEquals(["optional_string"], list(proto.repeated_string))
135
136  def testRepeatedCompositeConstructor(self):
137    # Constructor with only repeated composite types should succeed.
138    proto = unittest_pb2.TestAllTypes(
139        repeated_nested_message=[
140            unittest_pb2.TestAllTypes.NestedMessage(
141                bb=unittest_pb2.TestAllTypes.FOO),
142            unittest_pb2.TestAllTypes.NestedMessage(
143                bb=unittest_pb2.TestAllTypes.BAR)],
144        repeated_foreign_message=[
145            unittest_pb2.ForeignMessage(c=-43),
146            unittest_pb2.ForeignMessage(c=45324),
147            unittest_pb2.ForeignMessage(c=12)],
148        repeatedgroup=[
149            unittest_pb2.TestAllTypes.RepeatedGroup(),
150            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
151            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
152
153    self.assertEquals(
154        [unittest_pb2.TestAllTypes.NestedMessage(
155            bb=unittest_pb2.TestAllTypes.FOO),
156         unittest_pb2.TestAllTypes.NestedMessage(
157             bb=unittest_pb2.TestAllTypes.BAR)],
158        list(proto.repeated_nested_message))
159    self.assertEquals(
160        [unittest_pb2.ForeignMessage(c=-43),
161         unittest_pb2.ForeignMessage(c=45324),
162         unittest_pb2.ForeignMessage(c=12)],
163        list(proto.repeated_foreign_message))
164    self.assertEquals(
165        [unittest_pb2.TestAllTypes.RepeatedGroup(),
166         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
167         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
168        list(proto.repeatedgroup))
169
170  def testMixedConstructor(self):
171    # Constructor with only mixed types should succeed.
172    proto = unittest_pb2.TestAllTypes(
173        optional_int32=24,
174        optional_string='optional_string',
175        repeated_double=[1.23, 54.321],
176        repeated_bool=[True, False, False],
177        repeated_nested_message=[
178            unittest_pb2.TestAllTypes.NestedMessage(
179                bb=unittest_pb2.TestAllTypes.FOO),
180            unittest_pb2.TestAllTypes.NestedMessage(
181                bb=unittest_pb2.TestAllTypes.BAR)],
182        repeated_foreign_message=[
183            unittest_pb2.ForeignMessage(c=-43),
184            unittest_pb2.ForeignMessage(c=45324),
185            unittest_pb2.ForeignMessage(c=12)])
186
187    self.assertEqual(24, proto.optional_int32)
188    self.assertEqual('optional_string', proto.optional_string)
189    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
190    self.assertEquals([True, False, False], list(proto.repeated_bool))
191    self.assertEquals(
192        [unittest_pb2.TestAllTypes.NestedMessage(
193            bb=unittest_pb2.TestAllTypes.FOO),
194         unittest_pb2.TestAllTypes.NestedMessage(
195             bb=unittest_pb2.TestAllTypes.BAR)],
196        list(proto.repeated_nested_message))
197    self.assertEquals(
198        [unittest_pb2.ForeignMessage(c=-43),
199         unittest_pb2.ForeignMessage(c=45324),
200         unittest_pb2.ForeignMessage(c=12)],
201        list(proto.repeated_foreign_message))
202
203  def testSimpleHasBits(self):
204    # Test a scalar.
205    proto = unittest_pb2.TestAllTypes()
206    self.assertTrue(not proto.HasField('optional_int32'))
207    self.assertEqual(0, proto.optional_int32)
208    # HasField() shouldn't be true if all we've done is
209    # read the default value.
210    self.assertTrue(not proto.HasField('optional_int32'))
211    proto.optional_int32 = 1
212    # Setting a value however *should* set the "has" bit.
213    self.assertTrue(proto.HasField('optional_int32'))
214    proto.ClearField('optional_int32')
215    # And clearing that value should unset the "has" bit.
216    self.assertTrue(not proto.HasField('optional_int32'))
217
218  def testHasBitsWithSinglyNestedScalar(self):
219    # Helper used to test foreign messages and groups.
220    #
221    # composite_field_name should be the name of a non-repeated
222    # composite (i.e., foreign or group) field in TestAllTypes,
223    # and scalar_field_name should be the name of an integer-valued
224    # scalar field within that composite.
225    #
226    # I never thought I'd miss C++ macros and templates so much. :(
227    # This helper is semantically just:
228    #
229    #   assert proto.composite_field.scalar_field == 0
230    #   assert not proto.composite_field.HasField('scalar_field')
231    #   assert not proto.HasField('composite_field')
232    #
233    #   proto.composite_field.scalar_field = 10
234    #   old_composite_field = proto.composite_field
235    #
236    #   assert proto.composite_field.scalar_field == 10
237    #   assert proto.composite_field.HasField('scalar_field')
238    #   assert proto.HasField('composite_field')
239    #
240    #   proto.ClearField('composite_field')
241    #
242    #   assert not proto.composite_field.HasField('scalar_field')
243    #   assert not proto.HasField('composite_field')
244    #   assert proto.composite_field.scalar_field == 0
245    #
246    #   # Now ensure that ClearField('composite_field') disconnected
247    #   # the old field object from the object tree...
248    #   assert old_composite_field is not proto.composite_field
249    #   old_composite_field.scalar_field = 20
250    #   assert not proto.composite_field.HasField('scalar_field')
251    #   assert not proto.HasField('composite_field')
252    def TestCompositeHasBits(composite_field_name, scalar_field_name):
253      proto = unittest_pb2.TestAllTypes()
254      # First, check that we can get the scalar value, and see that it's the
255      # default (0), but that proto.HasField('omposite') and
256      # proto.composite.HasField('scalar') will still return False.
257      composite_field = getattr(proto, composite_field_name)
258      original_scalar_value = getattr(composite_field, scalar_field_name)
259      self.assertEqual(0, original_scalar_value)
260      # Assert that the composite object does not "have" the scalar.
261      self.assertTrue(not composite_field.HasField(scalar_field_name))
262      # Assert that proto does not "have" the composite field.
263      self.assertTrue(not proto.HasField(composite_field_name))
264
265      # Now set the scalar within the composite field.  Ensure that the setting
266      # is reflected, and that proto.HasField('composite') and
267      # proto.composite.HasField('scalar') now both return True.
268      new_val = 20
269      setattr(composite_field, scalar_field_name, new_val)
270      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
271      # Hold on to a reference to the current composite_field object.
272      old_composite_field = composite_field
273      # Assert that the has methods now return true.
274      self.assertTrue(composite_field.HasField(scalar_field_name))
275      self.assertTrue(proto.HasField(composite_field_name))
276
277      # Now call the clear method...
278      proto.ClearField(composite_field_name)
279
280      # ...and ensure that the "has" bits are all back to False...
281      composite_field = getattr(proto, composite_field_name)
282      self.assertTrue(not composite_field.HasField(scalar_field_name))
283      self.assertTrue(not proto.HasField(composite_field_name))
284      # ...and ensure that the scalar field has returned to its default.
285      self.assertEqual(0, getattr(composite_field, scalar_field_name))
286
287      # Finally, ensure that modifications to the old composite field object
288      # don't have any effect on the parent.
289      #
290      # (NOTE that when we clear the composite field in the parent, we actually
291      # don't recursively clear down the tree.  Instead, we just disconnect the
292      # cleared composite from the tree.)
293      self.assertTrue(old_composite_field is not composite_field)
294      setattr(old_composite_field, scalar_field_name, new_val)
295      self.assertTrue(not composite_field.HasField(scalar_field_name))
296      self.assertTrue(not proto.HasField(composite_field_name))
297      self.assertEqual(0, getattr(composite_field, scalar_field_name))
298
299    # Test simple, single-level nesting when we set a scalar.
300    TestCompositeHasBits('optionalgroup', 'a')
301    TestCompositeHasBits('optional_nested_message', 'bb')
302    TestCompositeHasBits('optional_foreign_message', 'c')
303    TestCompositeHasBits('optional_import_message', 'd')
304
305  def testReferencesToNestedMessage(self):
306    proto = unittest_pb2.TestAllTypes()
307    nested = proto.optional_nested_message
308    del proto
309    # A previous version had a bug where this would raise an exception when
310    # hitting a now-dead weak reference.
311    nested.bb = 23
312
313  def testDisconnectingNestedMessageBeforeSettingField(self):
314    proto = unittest_pb2.TestAllTypes()
315    nested = proto.optional_nested_message
316    proto.ClearField('optional_nested_message')  # Should disconnect from parent
317    self.assertTrue(nested is not proto.optional_nested_message)
318    nested.bb = 23
319    self.assertTrue(not proto.HasField('optional_nested_message'))
320    self.assertEqual(0, proto.optional_nested_message.bb)
321
322  def testHasBitsWhenModifyingRepeatedFields(self):
323    # Test nesting when we add an element to a repeated field in a submessage.
324    proto = unittest_pb2.TestNestedMessageHasBits()
325    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
326    self.assertEqual(
327        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
328    self.assertTrue(proto.HasField('optional_nested_message'))
329
330    # Do the same test, but with a repeated composite field within the
331    # submessage.
332    proto.ClearField('optional_nested_message')
333    self.assertTrue(not proto.HasField('optional_nested_message'))
334    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
335    self.assertTrue(proto.HasField('optional_nested_message'))
336
337  def testHasBitsForManyLevelsOfNesting(self):
338    # Test nesting many levels deep.
339    recursive_proto = unittest_pb2.TestMutualRecursionA()
340    self.assertTrue(not recursive_proto.HasField('bb'))
341    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
342    self.assertTrue(not recursive_proto.HasField('bb'))
343    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
344    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
345    self.assertTrue(recursive_proto.HasField('bb'))
346    self.assertTrue(recursive_proto.bb.HasField('a'))
347    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
348    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
349    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
350    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
351    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
352
353  def testSingularListFields(self):
354    proto = unittest_pb2.TestAllTypes()
355    proto.optional_fixed32 = 1
356    proto.optional_int32 = 5
357    proto.optional_string = 'foo'
358    # Access sub-message but don't set it yet.
359    nested_message = proto.optional_nested_message
360    self.assertEqual(
361      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
362        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
363        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
364      proto.ListFields())
365
366    proto.optional_nested_message.bb = 123
367    self.assertEqual(
368      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
369        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
370        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
371        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
372             nested_message) ],
373      proto.ListFields())
374
375  def testRepeatedListFields(self):
376    proto = unittest_pb2.TestAllTypes()
377    proto.repeated_fixed32.append(1)
378    proto.repeated_int32.append(5)
379    proto.repeated_int32.append(11)
380    proto.repeated_string.extend(['foo', 'bar'])
381    proto.repeated_string.extend([])
382    proto.repeated_string.append('baz')
383    proto.repeated_string.extend(str(x) for x in xrange(2))
384    proto.optional_int32 = 21
385    proto.repeated_bool  # Access but don't set anything; should not be listed.
386    self.assertEqual(
387      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
388        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
389        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
390        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
391          ['foo', 'bar', 'baz', '0', '1']) ],
392      proto.ListFields())
393
394  def testSingularListExtensions(self):
395    proto = unittest_pb2.TestAllExtensions()
396    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
397    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
398    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
399    self.assertEqual(
400      [ (unittest_pb2.optional_int32_extension  , 5),
401        (unittest_pb2.optional_fixed32_extension, 1),
402        (unittest_pb2.optional_string_extension , 'foo') ],
403      proto.ListFields())
404
405  def testRepeatedListExtensions(self):
406    proto = unittest_pb2.TestAllExtensions()
407    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
408    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
409    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
410    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
411    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
412    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
413    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
414    self.assertEqual(
415      [ (unittest_pb2.optional_int32_extension  , 21),
416        (unittest_pb2.repeated_int32_extension  , [5, 11]),
417        (unittest_pb2.repeated_fixed32_extension, [1]),
418        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
419      proto.ListFields())
420
421  def testListFieldsAndExtensions(self):
422    proto = unittest_pb2.TestFieldOrderings()
423    test_util.SetAllFieldsAndExtensions(proto)
424    unittest_pb2.my_extension_int
425    self.assertEqual(
426      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
427        (unittest_pb2.my_extension_int               , 23),
428        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
429        (unittest_pb2.my_extension_string            , 'bar'),
430        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
431      proto.ListFields())
432
433  def testDefaultValues(self):
434    proto = unittest_pb2.TestAllTypes()
435    self.assertEqual(0, proto.optional_int32)
436    self.assertEqual(0, proto.optional_int64)
437    self.assertEqual(0, proto.optional_uint32)
438    self.assertEqual(0, proto.optional_uint64)
439    self.assertEqual(0, proto.optional_sint32)
440    self.assertEqual(0, proto.optional_sint64)
441    self.assertEqual(0, proto.optional_fixed32)
442    self.assertEqual(0, proto.optional_fixed64)
443    self.assertEqual(0, proto.optional_sfixed32)
444    self.assertEqual(0, proto.optional_sfixed64)
445    self.assertEqual(0.0, proto.optional_float)
446    self.assertEqual(0.0, proto.optional_double)
447    self.assertEqual(False, proto.optional_bool)
448    self.assertEqual('', proto.optional_string)
449    self.assertEqual('', proto.optional_bytes)
450
451    self.assertEqual(41, proto.default_int32)
452    self.assertEqual(42, proto.default_int64)
453    self.assertEqual(43, proto.default_uint32)
454    self.assertEqual(44, proto.default_uint64)
455    self.assertEqual(-45, proto.default_sint32)
456    self.assertEqual(46, proto.default_sint64)
457    self.assertEqual(47, proto.default_fixed32)
458    self.assertEqual(48, proto.default_fixed64)
459    self.assertEqual(49, proto.default_sfixed32)
460    self.assertEqual(-50, proto.default_sfixed64)
461    self.assertEqual(51.5, proto.default_float)
462    self.assertEqual(52e3, proto.default_double)
463    self.assertEqual(True, proto.default_bool)
464    self.assertEqual('hello', proto.default_string)
465    self.assertEqual('world', proto.default_bytes)
466    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
467    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
468    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
469                     proto.default_import_enum)
470
471    proto = unittest_pb2.TestExtremeDefaultValues()
472    self.assertEqual(u'\u1234', proto.utf8_string)
473
474  def testHasFieldWithUnknownFieldName(self):
475    proto = unittest_pb2.TestAllTypes()
476    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
477
478  def testClearFieldWithUnknownFieldName(self):
479    proto = unittest_pb2.TestAllTypes()
480    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
481
482  def testDisallowedAssignments(self):
483    # It's illegal to assign values directly to repeated fields
484    # or to nonrepeated composite fields.  Ensure that this fails.
485    proto = unittest_pb2.TestAllTypes()
486    # Repeated fields.
487    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
488    # Lists shouldn't work, either.
489    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
490    # Composite fields.
491    self.assertRaises(AttributeError, setattr, proto,
492                      'optional_nested_message', 23)
493    # Assignment to a repeated nested message field without specifying
494    # the index in the array of nested messages.
495    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
496                      'bb', 34)
497    # Assignment to an attribute of a repeated field.
498    self.assertRaises(AttributeError, setattr, proto.repeated_float,
499                      'some_attribute', 34)
500    # proto.nonexistent_field = 23 should fail as well.
501    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
502
503  # TODO(robinson): Add type-safety check for enums.
504  def testSingleScalarTypeSafety(self):
505    proto = unittest_pb2.TestAllTypes()
506    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
507    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
508    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
509    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
510
511  def testSingleScalarBoundsChecking(self):
512    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
513      pb = unittest_pb2.TestAllTypes()
514      setattr(pb, field_name, expected_min)
515      setattr(pb, field_name, expected_max)
516      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
517      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
518
519    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
520    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
521    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
522    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
523    TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
524
525  def testRepeatedScalarTypeSafety(self):
526    proto = unittest_pb2.TestAllTypes()
527    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
528    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
529    self.assertRaises(TypeError, proto.repeated_string, 10)
530    self.assertRaises(TypeError, proto.repeated_bytes, 10)
531
532    proto.repeated_int32.append(10)
533    proto.repeated_int32[0] = 23
534    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
535    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
536
537  def testSingleScalarGettersAndSetters(self):
538    proto = unittest_pb2.TestAllTypes()
539    self.assertEqual(0, proto.optional_int32)
540    proto.optional_int32 = 1
541    self.assertEqual(1, proto.optional_int32)
542    # TODO(robinson): Test all other scalar field types.
543
544  def testSingleScalarClearField(self):
545    proto = unittest_pb2.TestAllTypes()
546    # Should be allowed to clear something that's not there (a no-op).
547    proto.ClearField('optional_int32')
548    proto.optional_int32 = 1
549    self.assertTrue(proto.HasField('optional_int32'))
550    proto.ClearField('optional_int32')
551    self.assertEqual(0, proto.optional_int32)
552    self.assertTrue(not proto.HasField('optional_int32'))
553    # TODO(robinson): Test all other scalar field types.
554
555  def testEnums(self):
556    proto = unittest_pb2.TestAllTypes()
557    self.assertEqual(1, proto.FOO)
558    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
559    self.assertEqual(2, proto.BAR)
560    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
561    self.assertEqual(3, proto.BAZ)
562    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
563
564  def testRepeatedScalars(self):
565    proto = unittest_pb2.TestAllTypes()
566
567    self.assertTrue(not proto.repeated_int32)
568    self.assertEqual(0, len(proto.repeated_int32))
569    proto.repeated_int32.append(5)
570    proto.repeated_int32.append(10)
571    proto.repeated_int32.append(15)
572    self.assertTrue(proto.repeated_int32)
573    self.assertEqual(3, len(proto.repeated_int32))
574
575    self.assertEqual([5, 10, 15], proto.repeated_int32)
576
577    # Test single retrieval.
578    self.assertEqual(5, proto.repeated_int32[0])
579    self.assertEqual(15, proto.repeated_int32[-1])
580    # Test out-of-bounds indices.
581    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
582    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
583    # Test incorrect types passed to __getitem__.
584    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
585    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
586
587    # Test single assignment.
588    proto.repeated_int32[1] = 20
589    self.assertEqual([5, 20, 15], proto.repeated_int32)
590
591    # Test insertion.
592    proto.repeated_int32.insert(1, 25)
593    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
594
595    # Test slice retrieval.
596    proto.repeated_int32.append(30)
597    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
598    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
599
600    # Test slice assignment with an iterator
601    proto.repeated_int32[1:4] = (i for i in xrange(3))
602    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
603
604    # Test slice assignment.
605    proto.repeated_int32[1:4] = [35, 40, 45]
606    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
607
608    # Test that we can use the field as an iterator.
609    result = []
610    for i in proto.repeated_int32:
611      result.append(i)
612    self.assertEqual([5, 35, 40, 45, 30], result)
613
614    # Test single deletion.
615    del proto.repeated_int32[2]
616    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
617
618    # Test slice deletion.
619    del proto.repeated_int32[2:]
620    self.assertEqual([5, 35], proto.repeated_int32)
621
622    # Test clearing.
623    proto.ClearField('repeated_int32')
624    self.assertTrue(not proto.repeated_int32)
625    self.assertEqual(0, len(proto.repeated_int32))
626
627  def testRepeatedScalarsRemove(self):
628    proto = unittest_pb2.TestAllTypes()
629
630    self.assertTrue(not proto.repeated_int32)
631    self.assertEqual(0, len(proto.repeated_int32))
632    proto.repeated_int32.append(5)
633    proto.repeated_int32.append(10)
634    proto.repeated_int32.append(5)
635    proto.repeated_int32.append(5)
636
637    self.assertEqual(4, len(proto.repeated_int32))
638    proto.repeated_int32.remove(5)
639    self.assertEqual(3, len(proto.repeated_int32))
640    self.assertEqual(10, proto.repeated_int32[0])
641    self.assertEqual(5, proto.repeated_int32[1])
642    self.assertEqual(5, proto.repeated_int32[2])
643
644    proto.repeated_int32.remove(5)
645    self.assertEqual(2, len(proto.repeated_int32))
646    self.assertEqual(10, proto.repeated_int32[0])
647    self.assertEqual(5, proto.repeated_int32[1])
648
649    proto.repeated_int32.remove(10)
650    self.assertEqual(1, len(proto.repeated_int32))
651    self.assertEqual(5, proto.repeated_int32[0])
652
653    # Remove a non-existent element.
654    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
655
656  def testRepeatedComposites(self):
657    proto = unittest_pb2.TestAllTypes()
658    self.assertTrue(not proto.repeated_nested_message)
659    self.assertEqual(0, len(proto.repeated_nested_message))
660    m0 = proto.repeated_nested_message.add()
661    m1 = proto.repeated_nested_message.add()
662    self.assertTrue(proto.repeated_nested_message)
663    self.assertEqual(2, len(proto.repeated_nested_message))
664    self.assertIs([m0, m1], proto.repeated_nested_message)
665    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
666
667    # Test out-of-bounds indices.
668    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
669                      1234)
670    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
671                      -1234)
672
673    # Test incorrect types passed to __getitem__.
674    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
675                      'foo')
676    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
677                      None)
678
679    # Test slice retrieval.
680    m2 = proto.repeated_nested_message.add()
681    m3 = proto.repeated_nested_message.add()
682    m4 = proto.repeated_nested_message.add()
683    self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
684    self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
685
686    # Test that we can use the field as an iterator.
687    result = []
688    for i in proto.repeated_nested_message:
689      result.append(i)
690    self.assertIs([m0, m1, m2, m3, m4], result)
691
692    # Test single deletion.
693    del proto.repeated_nested_message[2]
694    self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
695
696    # Test slice deletion.
697    del proto.repeated_nested_message[2:]
698    self.assertIs([m0, m1], proto.repeated_nested_message)
699
700    # Test clearing.
701    proto.ClearField('repeated_nested_message')
702    self.assertTrue(not proto.repeated_nested_message)
703    self.assertEqual(0, len(proto.repeated_nested_message))
704
705  def testHandWrittenReflection(self):
706    # TODO(robinson): We probably need a better way to specify
707    # protocol types by hand.  But then again, this isn't something
708    # we expect many people to do.  Hmm.
709    FieldDescriptor = descriptor.FieldDescriptor
710    foo_field_descriptor = FieldDescriptor(
711        name='foo_field', full_name='MyProto.foo_field',
712        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
713        cpp_type=FieldDescriptor.CPPTYPE_INT64,
714        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
715        containing_type=None, message_type=None, enum_type=None,
716        is_extension=False, extension_scope=None,
717        options=descriptor_pb2.FieldOptions())
718    mydescriptor = descriptor.Descriptor(
719        name='MyProto', full_name='MyProto', filename='ignored',
720        containing_type=None, nested_types=[], enum_types=[],
721        fields=[foo_field_descriptor], extensions=[],
722        options=descriptor_pb2.MessageOptions())
723    class MyProtoClass(message.Message):
724      DESCRIPTOR = mydescriptor
725      __metaclass__ = reflection.GeneratedProtocolMessageType
726    myproto_instance = MyProtoClass()
727    self.assertEqual(0, myproto_instance.foo_field)
728    self.assertTrue(not myproto_instance.HasField('foo_field'))
729    myproto_instance.foo_field = 23
730    self.assertEqual(23, myproto_instance.foo_field)
731    self.assertTrue(myproto_instance.HasField('foo_field'))
732
733  def testTopLevelExtensionsForOptionalScalar(self):
734    extendee_proto = unittest_pb2.TestAllExtensions()
735    extension = unittest_pb2.optional_int32_extension
736    self.assertTrue(not extendee_proto.HasExtension(extension))
737    self.assertEqual(0, extendee_proto.Extensions[extension])
738    # As with normal scalar fields, just doing a read doesn't actually set the
739    # "has" bit.
740    self.assertTrue(not extendee_proto.HasExtension(extension))
741    # Actually set the thing.
742    extendee_proto.Extensions[extension] = 23
743    self.assertEqual(23, extendee_proto.Extensions[extension])
744    self.assertTrue(extendee_proto.HasExtension(extension))
745    # Ensure that clearing works as well.
746    extendee_proto.ClearExtension(extension)
747    self.assertEqual(0, extendee_proto.Extensions[extension])
748    self.assertTrue(not extendee_proto.HasExtension(extension))
749
750  def testTopLevelExtensionsForRepeatedScalar(self):
751    extendee_proto = unittest_pb2.TestAllExtensions()
752    extension = unittest_pb2.repeated_string_extension
753    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
754    extendee_proto.Extensions[extension].append('foo')
755    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
756    string_list = extendee_proto.Extensions[extension]
757    extendee_proto.ClearExtension(extension)
758    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
759    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
760    # Shouldn't be allowed to do Extensions[extension] = 'a'
761    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
762                      extension, 'a')
763
764  def testTopLevelExtensionsForOptionalMessage(self):
765    extendee_proto = unittest_pb2.TestAllExtensions()
766    extension = unittest_pb2.optional_foreign_message_extension
767    self.assertTrue(not extendee_proto.HasExtension(extension))
768    self.assertEqual(0, extendee_proto.Extensions[extension].c)
769    # As with normal (non-extension) fields, merely reading from the
770    # thing shouldn't set the "has" bit.
771    self.assertTrue(not extendee_proto.HasExtension(extension))
772    extendee_proto.Extensions[extension].c = 23
773    self.assertEqual(23, extendee_proto.Extensions[extension].c)
774    self.assertTrue(extendee_proto.HasExtension(extension))
775    # Save a reference here.
776    foreign_message = extendee_proto.Extensions[extension]
777    extendee_proto.ClearExtension(extension)
778    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
779    # Setting a field on foreign_message now shouldn't set
780    # any "has" bits on extendee_proto.
781    foreign_message.c = 42
782    self.assertEqual(42, foreign_message.c)
783    self.assertTrue(foreign_message.HasField('c'))
784    self.assertTrue(not extendee_proto.HasExtension(extension))
785    # Shouldn't be allowed to do Extensions[extension] = 'a'
786    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
787                      extension, 'a')
788
789  def testTopLevelExtensionsForRepeatedMessage(self):
790    extendee_proto = unittest_pb2.TestAllExtensions()
791    extension = unittest_pb2.repeatedgroup_extension
792    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
793    group = extendee_proto.Extensions[extension].add()
794    group.a = 23
795    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
796    group.a = 42
797    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
798    group_list = extendee_proto.Extensions[extension]
799    extendee_proto.ClearExtension(extension)
800    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
801    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
802    # Shouldn't be allowed to do Extensions[extension] = 'a'
803    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
804                      extension, 'a')
805
806  def testNestedExtensions(self):
807    extendee_proto = unittest_pb2.TestAllExtensions()
808    extension = unittest_pb2.TestRequired.single
809
810    # We just test the non-repeated case.
811    self.assertTrue(not extendee_proto.HasExtension(extension))
812    required = extendee_proto.Extensions[extension]
813    self.assertEqual(0, required.a)
814    self.assertTrue(not extendee_proto.HasExtension(extension))
815    required.a = 23
816    self.assertEqual(23, extendee_proto.Extensions[extension].a)
817    self.assertTrue(extendee_proto.HasExtension(extension))
818    extendee_proto.ClearExtension(extension)
819    self.assertTrue(required is not extendee_proto.Extensions[extension])
820    self.assertTrue(not extendee_proto.HasExtension(extension))
821
822  # If message A directly contains message B, and
823  # a.HasField('b') is currently False, then mutating any
824  # extension in B should change a.HasField('b') to True
825  # (and so on up the object tree).
826  def testHasBitsForAncestorsOfExtendedMessage(self):
827    # Optional scalar extension.
828    toplevel = more_extensions_pb2.TopLevelMessage()
829    self.assertTrue(not toplevel.HasField('submessage'))
830    self.assertEqual(0, toplevel.submessage.Extensions[
831        more_extensions_pb2.optional_int_extension])
832    self.assertTrue(not toplevel.HasField('submessage'))
833    toplevel.submessage.Extensions[
834        more_extensions_pb2.optional_int_extension] = 23
835    self.assertEqual(23, toplevel.submessage.Extensions[
836        more_extensions_pb2.optional_int_extension])
837    self.assertTrue(toplevel.HasField('submessage'))
838
839    # Repeated scalar extension.
840    toplevel = more_extensions_pb2.TopLevelMessage()
841    self.assertTrue(not toplevel.HasField('submessage'))
842    self.assertEqual([], toplevel.submessage.Extensions[
843        more_extensions_pb2.repeated_int_extension])
844    self.assertTrue(not toplevel.HasField('submessage'))
845    toplevel.submessage.Extensions[
846        more_extensions_pb2.repeated_int_extension].append(23)
847    self.assertEqual([23], toplevel.submessage.Extensions[
848        more_extensions_pb2.repeated_int_extension])
849    self.assertTrue(toplevel.HasField('submessage'))
850
851    # Optional message extension.
852    toplevel = more_extensions_pb2.TopLevelMessage()
853    self.assertTrue(not toplevel.HasField('submessage'))
854    self.assertEqual(0, toplevel.submessage.Extensions[
855        more_extensions_pb2.optional_message_extension].foreign_message_int)
856    self.assertTrue(not toplevel.HasField('submessage'))
857    toplevel.submessage.Extensions[
858        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
859    self.assertEqual(23, toplevel.submessage.Extensions[
860        more_extensions_pb2.optional_message_extension].foreign_message_int)
861    self.assertTrue(toplevel.HasField('submessage'))
862
863    # Repeated message extension.
864    toplevel = more_extensions_pb2.TopLevelMessage()
865    self.assertTrue(not toplevel.HasField('submessage'))
866    self.assertEqual(0, len(toplevel.submessage.Extensions[
867        more_extensions_pb2.repeated_message_extension]))
868    self.assertTrue(not toplevel.HasField('submessage'))
869    foreign = toplevel.submessage.Extensions[
870        more_extensions_pb2.repeated_message_extension].add()
871    self.assertTrue(foreign is toplevel.submessage.Extensions[
872        more_extensions_pb2.repeated_message_extension][0])
873    self.assertTrue(toplevel.HasField('submessage'))
874
875  def testDisconnectionAfterClearingEmptyMessage(self):
876    toplevel = more_extensions_pb2.TopLevelMessage()
877    extendee_proto = toplevel.submessage
878    extension = more_extensions_pb2.optional_message_extension
879    extension_proto = extendee_proto.Extensions[extension]
880    extendee_proto.ClearExtension(extension)
881    extension_proto.foreign_message_int = 23
882
883    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
884
885  def testExtensionFailureModes(self):
886    extendee_proto = unittest_pb2.TestAllExtensions()
887
888    # Try non-extension-handle arguments to HasExtension,
889    # ClearExtension(), and Extensions[]...
890    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
891    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
892    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
893    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
894
895    # Try something that *is* an extension handle, just not for
896    # this message...
897    unknown_handle = more_extensions_pb2.optional_int_extension
898    self.assertRaises(KeyError, extendee_proto.HasExtension,
899                      unknown_handle)
900    self.assertRaises(KeyError, extendee_proto.ClearExtension,
901                      unknown_handle)
902    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
903                      unknown_handle)
904    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
905                      unknown_handle, 5)
906
907    # Try call HasExtension() with a valid handle, but for a
908    # *repeated* field.  (Just as with non-extension repeated
909    # fields, Has*() isn't supported for extension repeated fields).
910    self.assertRaises(KeyError, extendee_proto.HasExtension,
911                      unittest_pb2.repeated_string_extension)
912
913  def testStaticParseFrom(self):
914    proto1 = unittest_pb2.TestAllTypes()
915    test_util.SetAllFields(proto1)
916
917    string1 = proto1.SerializeToString()
918    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
919
920    # Messages should be equal.
921    self.assertEqual(proto2, proto1)
922
923  def testMergeFromSingularField(self):
924    # Test merge with just a singular field.
925    proto1 = unittest_pb2.TestAllTypes()
926    proto1.optional_int32 = 1
927
928    proto2 = unittest_pb2.TestAllTypes()
929    # This shouldn't get overwritten.
930    proto2.optional_string = 'value'
931
932    proto2.MergeFrom(proto1)
933    self.assertEqual(1, proto2.optional_int32)
934    self.assertEqual('value', proto2.optional_string)
935
936  def testMergeFromRepeatedField(self):
937    # Test merge with just a repeated field.
938    proto1 = unittest_pb2.TestAllTypes()
939    proto1.repeated_int32.append(1)
940    proto1.repeated_int32.append(2)
941
942    proto2 = unittest_pb2.TestAllTypes()
943    proto2.repeated_int32.append(0)
944    proto2.MergeFrom(proto1)
945
946    self.assertEqual(0, proto2.repeated_int32[0])
947    self.assertEqual(1, proto2.repeated_int32[1])
948    self.assertEqual(2, proto2.repeated_int32[2])
949
950  def testMergeFromOptionalGroup(self):
951    # Test merge with an optional group.
952    proto1 = unittest_pb2.TestAllTypes()
953    proto1.optionalgroup.a = 12
954    proto2 = unittest_pb2.TestAllTypes()
955    proto2.MergeFrom(proto1)
956    self.assertEqual(12, proto2.optionalgroup.a)
957
958  def testMergeFromRepeatedNestedMessage(self):
959    # Test merge with a repeated nested message.
960    proto1 = unittest_pb2.TestAllTypes()
961    m = proto1.repeated_nested_message.add()
962    m.bb = 123
963    m = proto1.repeated_nested_message.add()
964    m.bb = 321
965
966    proto2 = unittest_pb2.TestAllTypes()
967    m = proto2.repeated_nested_message.add()
968    m.bb = 999
969    proto2.MergeFrom(proto1)
970    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
971    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
972    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
973
974  def testMergeFromAllFields(self):
975    # With all fields set.
976    proto1 = unittest_pb2.TestAllTypes()
977    test_util.SetAllFields(proto1)
978    proto2 = unittest_pb2.TestAllTypes()
979    proto2.MergeFrom(proto1)
980
981    # Messages should be equal.
982    self.assertEqual(proto2, proto1)
983
984    # Serialized string should be equal too.
985    string1 = proto1.SerializeToString()
986    string2 = proto2.SerializeToString()
987    self.assertEqual(string1, string2)
988
989  def testMergeFromExtensionsSingular(self):
990    proto1 = unittest_pb2.TestAllExtensions()
991    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
992
993    proto2 = unittest_pb2.TestAllExtensions()
994    proto2.MergeFrom(proto1)
995    self.assertEqual(
996        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
997
998  def testMergeFromExtensionsRepeated(self):
999    proto1 = unittest_pb2.TestAllExtensions()
1000    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1001    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1002
1003    proto2 = unittest_pb2.TestAllExtensions()
1004    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1005    proto2.MergeFrom(proto1)
1006    self.assertEqual(
1007        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1008    self.assertEqual(
1009        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1010    self.assertEqual(
1011        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1012    self.assertEqual(
1013        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1014
1015  def testMergeFromExtensionsNestedMessage(self):
1016    proto1 = unittest_pb2.TestAllExtensions()
1017    ext1 = proto1.Extensions[
1018        unittest_pb2.repeated_nested_message_extension]
1019    m = ext1.add()
1020    m.bb = 222
1021    m = ext1.add()
1022    m.bb = 333
1023
1024    proto2 = unittest_pb2.TestAllExtensions()
1025    ext2 = proto2.Extensions[
1026        unittest_pb2.repeated_nested_message_extension]
1027    m = ext2.add()
1028    m.bb = 111
1029
1030    proto2.MergeFrom(proto1)
1031    ext2 = proto2.Extensions[
1032        unittest_pb2.repeated_nested_message_extension]
1033    self.assertEqual(3, len(ext2))
1034    self.assertEqual(111, ext2[0].bb)
1035    self.assertEqual(222, ext2[1].bb)
1036    self.assertEqual(333, ext2[2].bb)
1037
1038  def testCopyFromSingularField(self):
1039    # Test copy with just a singular field.
1040    proto1 = unittest_pb2.TestAllTypes()
1041    proto1.optional_int32 = 1
1042    proto1.optional_string = 'important-text'
1043
1044    proto2 = unittest_pb2.TestAllTypes()
1045    proto2.optional_string = 'value'
1046
1047    proto2.CopyFrom(proto1)
1048    self.assertEqual(1, proto2.optional_int32)
1049    self.assertEqual('important-text', proto2.optional_string)
1050
1051  def testCopyFromRepeatedField(self):
1052    # Test copy with a repeated field.
1053    proto1 = unittest_pb2.TestAllTypes()
1054    proto1.repeated_int32.append(1)
1055    proto1.repeated_int32.append(2)
1056
1057    proto2 = unittest_pb2.TestAllTypes()
1058    proto2.repeated_int32.append(0)
1059    proto2.CopyFrom(proto1)
1060
1061    self.assertEqual(1, proto2.repeated_int32[0])
1062    self.assertEqual(2, proto2.repeated_int32[1])
1063
1064  def testCopyFromAllFields(self):
1065    # With all fields set.
1066    proto1 = unittest_pb2.TestAllTypes()
1067    test_util.SetAllFields(proto1)
1068    proto2 = unittest_pb2.TestAllTypes()
1069    proto2.CopyFrom(proto1)
1070
1071    # Messages should be equal.
1072    self.assertEqual(proto2, proto1)
1073
1074    # Serialized string should be equal too.
1075    string1 = proto1.SerializeToString()
1076    string2 = proto2.SerializeToString()
1077    self.assertEqual(string1, string2)
1078
1079  def testCopyFromSelf(self):
1080    proto1 = unittest_pb2.TestAllTypes()
1081    proto1.repeated_int32.append(1)
1082    proto1.optional_int32 = 2
1083    proto1.optional_string = 'important-text'
1084
1085    proto1.CopyFrom(proto1)
1086    self.assertEqual(1, proto1.repeated_int32[0])
1087    self.assertEqual(2, proto1.optional_int32)
1088    self.assertEqual('important-text', proto1.optional_string)
1089
1090  def testClear(self):
1091    proto = unittest_pb2.TestAllTypes()
1092    test_util.SetAllFields(proto)
1093    # Clear the message.
1094    proto.Clear()
1095    self.assertEquals(proto.ByteSize(), 0)
1096    empty_proto = unittest_pb2.TestAllTypes()
1097    self.assertEquals(proto, empty_proto)
1098
1099    # Test if extensions which were set are cleared.
1100    proto = unittest_pb2.TestAllExtensions()
1101    test_util.SetAllExtensions(proto)
1102    # Clear the message.
1103    proto.Clear()
1104    self.assertEquals(proto.ByteSize(), 0)
1105    empty_proto = unittest_pb2.TestAllExtensions()
1106    self.assertEquals(proto, empty_proto)
1107
1108  def assertInitialized(self, proto):
1109    self.assertTrue(proto.IsInitialized())
1110    # Neither method should raise an exception.
1111    proto.SerializeToString()
1112    proto.SerializePartialToString()
1113
1114  def assertNotInitialized(self, proto):
1115    self.assertFalse(proto.IsInitialized())
1116    self.assertRaises(message.EncodeError, proto.SerializeToString)
1117    # "Partial" serialization doesn't care if message is uninitialized.
1118    proto.SerializePartialToString()
1119
1120  def testIsInitialized(self):
1121    # Trivial cases - all optional fields and extensions.
1122    proto = unittest_pb2.TestAllTypes()
1123    self.assertInitialized(proto)
1124    proto = unittest_pb2.TestAllExtensions()
1125    self.assertInitialized(proto)
1126
1127    # The case of uninitialized required fields.
1128    proto = unittest_pb2.TestRequired()
1129    self.assertNotInitialized(proto)
1130    proto.a = proto.b = proto.c = 2
1131    self.assertInitialized(proto)
1132
1133    # The case of uninitialized submessage.
1134    proto = unittest_pb2.TestRequiredForeign()
1135    self.assertInitialized(proto)
1136    proto.optional_message.a = 1
1137    self.assertNotInitialized(proto)
1138    proto.optional_message.b = 0
1139    proto.optional_message.c = 0
1140    self.assertInitialized(proto)
1141
1142    # Uninitialized repeated submessage.
1143    message1 = proto.repeated_message.add()
1144    self.assertNotInitialized(proto)
1145    message1.a = message1.b = message1.c = 0
1146    self.assertInitialized(proto)
1147
1148    # Uninitialized repeated group in an extension.
1149    proto = unittest_pb2.TestAllExtensions()
1150    extension = unittest_pb2.TestRequired.multi
1151    message1 = proto.Extensions[extension].add()
1152    message2 = proto.Extensions[extension].add()
1153    self.assertNotInitialized(proto)
1154    message1.a = 1
1155    message1.b = 1
1156    message1.c = 1
1157    self.assertNotInitialized(proto)
1158    message2.a = 2
1159    message2.b = 2
1160    message2.c = 2
1161    self.assertInitialized(proto)
1162
1163    # Uninitialized nonrepeated message in an extension.
1164    proto = unittest_pb2.TestAllExtensions()
1165    extension = unittest_pb2.TestRequired.single
1166    proto.Extensions[extension].a = 1
1167    self.assertNotInitialized(proto)
1168    proto.Extensions[extension].b = 2
1169    proto.Extensions[extension].c = 3
1170    self.assertInitialized(proto)
1171
1172    # Try passing an errors list.
1173    errors = []
1174    proto = unittest_pb2.TestRequired()
1175    self.assertFalse(proto.IsInitialized(errors))
1176    self.assertEqual(errors, ['a', 'b', 'c'])
1177
1178  def testStringUTF8Encoding(self):
1179    proto = unittest_pb2.TestAllTypes()
1180
1181    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
1182    self.assertRaises(TypeError,
1183                      setattr, proto, 'optional_bytes', u'unicode object')
1184
1185    # Check that the default value is of python's 'unicode' type.
1186    self.assertEqual(type(proto.optional_string), unicode)
1187
1188    proto.optional_string = unicode('Testing')
1189    self.assertEqual(proto.optional_string, str('Testing'))
1190
1191    # Assign a value of type 'str' which can be encoded in UTF-8.
1192    proto.optional_string = str('Testing')
1193    self.assertEqual(proto.optional_string, unicode('Testing'))
1194
1195    # Values of type 'str' are also accepted as long as they can be encoded in
1196    # UTF-8.
1197    self.assertEqual(type(proto.optional_string), str)
1198
1199    # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
1200    self.assertRaises(ValueError,
1201                      setattr, proto, 'optional_string', str('a\x80a'))
1202    # Assign a 'str' object which contains a UTF-8 encoded string.
1203    self.assertRaises(ValueError,
1204                      setattr, proto, 'optional_string', 'Тест')
1205    # No exception thrown.
1206    proto.optional_string = 'abc'
1207
1208  def testStringUTF8Serialization(self):
1209    proto = unittest_mset_pb2.TestMessageSet()
1210    extension_message = unittest_mset_pb2.TestMessageSetExtension2
1211    extension = extension_message.message_set_extension
1212
1213    test_utf8 = u'Тест'
1214    test_utf8_bytes = test_utf8.encode('utf-8')
1215
1216    # 'Test' in another language, using UTF-8 charset.
1217    proto.Extensions[extension].str = test_utf8
1218
1219    # Serialize using the MessageSet wire format (this is specified in the
1220    # .proto file).
1221    serialized = proto.SerializeToString()
1222
1223    # Check byte size.
1224    self.assertEqual(proto.ByteSize(), len(serialized))
1225
1226    raw = unittest_mset_pb2.RawMessageSet()
1227    raw.MergeFromString(serialized)
1228
1229    message2 = unittest_mset_pb2.TestMessageSetExtension2()
1230
1231    self.assertEqual(1, len(raw.item))
1232    # Check that the type_id is the same as the tag ID in the .proto file.
1233    self.assertEqual(raw.item[0].type_id, 1547769)
1234
1235    # Check the actually bytes on the wire.
1236    self.assertTrue(
1237        raw.item[0].message.endswith(test_utf8_bytes))
1238    message2.MergeFromString(raw.item[0].message)
1239
1240    self.assertEqual(type(message2.str), unicode)
1241    self.assertEqual(message2.str, test_utf8)
1242
1243    # How about if the bytes on the wire aren't a valid UTF-8 encoded string.
1244    bytes = raw.item[0].message.replace(
1245        test_utf8_bytes, len(test_utf8_bytes) * '\xff')
1246    self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
1247
1248  def testEmptyNestedMessage(self):
1249    proto = unittest_pb2.TestAllTypes()
1250    proto.optional_nested_message.MergeFrom(
1251        unittest_pb2.TestAllTypes.NestedMessage())
1252    self.assertTrue(proto.HasField('optional_nested_message'))
1253
1254    proto = unittest_pb2.TestAllTypes()
1255    proto.optional_nested_message.CopyFrom(
1256        unittest_pb2.TestAllTypes.NestedMessage())
1257    self.assertTrue(proto.HasField('optional_nested_message'))
1258
1259    proto = unittest_pb2.TestAllTypes()
1260    proto.optional_nested_message.MergeFromString('')
1261    self.assertTrue(proto.HasField('optional_nested_message'))
1262
1263    proto = unittest_pb2.TestAllTypes()
1264    proto.optional_nested_message.ParseFromString('')
1265    self.assertTrue(proto.HasField('optional_nested_message'))
1266
1267    serialized = proto.SerializeToString()
1268    proto2 = unittest_pb2.TestAllTypes()
1269    proto2.MergeFromString(serialized)
1270    self.assertTrue(proto2.HasField('optional_nested_message'))
1271
1272  def testSetInParent(self):
1273    proto = unittest_pb2.TestAllTypes()
1274    self.assertFalse(proto.HasField('optionalgroup'))
1275    proto.optionalgroup.SetInParent()
1276    self.assertTrue(proto.HasField('optionalgroup'))
1277
1278
1279#  Since we had so many tests for protocol buffer equality, we broke these out
1280#  into separate TestCase classes.
1281
1282
class TestAllTypesEqualityTest(unittest.TestCase):

  """Equality of freshly constructed, empty TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testSelfEquality(self):
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    # Two distinct empty instances compare equal.
    self.assertEqual(self.first_proto, self.second_proto)
1294
1295
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    for proto in (self.first_proto, self.second_proto):
      test_util.SetAllFields(proto)

  def _checkEqual(self):
    """Asserts the two fixture protos currently compare equal."""
    self.assertEqual(self.first_proto, self.second_proto)

  def _checkNotEqual(self):
    """Asserts the two fixture protos currently compare unequal."""
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNoneNotEqual(self):
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    other_type = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other_type)
    self.assertNotEqual(other_type, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self._checkEqual()

  def testNonRepeatedScalar(self):
    # Changing a singular scalar breaks equality ...
    self.first_proto.optional_int32 += 1
    self._checkNotEqual()
    # ... and so does clearing it.
    self.first_proto.ClearField('optional_int32')
    self._checkNotEqual()

  def testNonRepeatedComposite(self):
    # Nudge a nested scalar away and back again.
    self.first_proto.optional_nested_message.bb += 1
    self._checkNotEqual()
    self.first_proto.optional_nested_message.bb -= 1
    self._checkEqual()
    # Clearing the nested scalar breaks equality; restoring it repairs it.
    self.first_proto.optional_nested_message.ClearField('bb')
    self._checkNotEqual()
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self._checkEqual()
    # Dropping the whole nested message breaks equality again.
    self.first_proto.ClearField('optional_nested_message')
    self._checkNotEqual()

  def testRepeatedScalar(self):
    # Both an extra element and an emptied list break equality.
    self.first_proto.repeated_int32.append(5)
    self._checkNotEqual()
    self.first_proto.ClearField('repeated_int32')
    self._checkNotEqual()

  def testRepeatedComposite(self):
    # Mutating an element in place is visible to ==.
    self.first_proto.repeated_nested_message[0].bb += 1
    self._checkNotEqual()
    self.first_proto.repeated_nested_message[0].bb -= 1
    self._checkEqual()
    # An appended element breaks equality until the other side matches.
    self.first_proto.repeated_nested_message.add()
    self._checkNotEqual()
    self.second_proto.repeated_nested_message.add()
    self._checkEqual()

  def testNonRepeatedScalarHasBits(self):
    # An explicitly set 0 differs from an unset field: "has" bits count.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self._checkNotEqual()

  def testNonRepeatedCompositeHasBits(self):
    # A present sub-message (with bb cleared) differs from an absent one.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self._checkNotEqual()
    # TODO(robinson): Replace next two lines with method
    # to set the "has" bit without changing the value,
    # if/when such a method exists.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self._checkEqual()
1380
1381
class ExtensionEqualityTest(unittest.TestCase):

  """Equality comparisons that involve extension values and "has" bits."""

  def testExtensionEquality(self):
    int_ext = unittest_pb2.optional_int32_extension
    lhs = unittest_pb2.TestAllExtensions()
    rhs = unittest_pb2.TestAllExtensions()
    self.assertEqual(lhs, rhs)
    # Populating one side makes them differ; populating both matches them.
    test_util.SetAllExtensions(lhs)
    self.assertNotEqual(lhs, rhs)
    test_util.SetAllExtensions(rhs)
    self.assertEqual(lhs, rhs)

    # Extension values take part in the comparison.
    lhs.Extensions[int_ext] += 1
    self.assertNotEqual(lhs, rhs)
    lhs.Extensions[int_ext] -= 1
    self.assertEqual(lhs, rhs)

    # Presence matters too: an unset extension differs from an explicit 0.
    lhs.ClearExtension(int_ext)
    rhs.Extensions[int_ext] = 0
    self.assertNotEqual(lhs, rhs)
    lhs.Extensions[int_ext] = 0
    self.assertEqual(lhs, rhs)

    # A default read on one side is ignored while both "has" bits are false.
    lhs = unittest_pb2.TestAllExtensions()
    rhs = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, lhs.Extensions[int_ext])
    self.assertEqual(lhs, rhs)
1413
1414
class MutualRecursionEqualityTest(unittest.TestCase):

  """Equality through message types that refer to each other."""

  def testEqualityWithMutualRecursion(self):
    lhs = unittest_pb2.TestMutualRecursionA()
    rhs = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(lhs, rhs)
    # A difference deep inside the recursive chain is detected ...
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    # ... and mirroring it on the other side restores equality.
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
1425
1426
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes.

  Each expected size is written as a sum of tag bytes, length-prefix bytes
  and value bytes, explained in the comment next to the assertion.
  """

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns self.proto.ByteSize()."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for the single tag byte.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # Varints carry 7 value bits per byte, so (1 << (7 * n)) - 1 is the
    # largest value that fits in n bytes.  n == 9 gives (1 << 63) - 1, the
    # largest int64, covering the longest non-negative encoding too.
    for num_bytes in range(1, 10):
      Test((1 << (7 * num_bytes)) - 1, num_bytes)
    # Negative values always take the full 10 bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # One byte for the tag (tag #14) and one byte for the length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # One byte for the tag and one byte for the length, plus the payload.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # One byte for the tag and TWO bytes for the length (128 > 127).
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_message = self.proto.repeated_nested_message.add()
    nested_message.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    second_message = self.proto.repeated_nested_message.add()
    second_message.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the empty message leaves only the one holding bb:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    third_message = self.proto.repeated_nested_message.add()
    third_message.bb = 12

    # Two messages, each 2 bytes tag plus 1 byte length plus 1 byte bb tag
    # plus 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the second message leaves one of the same size.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group = self.proto.repeatedgroup.add()
    group.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    """ByteSize() tracks set/overwrite/clear of a singular scalar."""
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    """ByteSize() tracks appends, item assignment and clearing."""
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    """ByteSize() tracks edits made through a singular sub-message."""
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())
    # A child detached by ClearField must no longer affect the parent.
    child = self.proto.optional_foreign_message
    self.proto.ClearField('optional_foreign_message')
    child.c = 128
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    """ByteSize() tracks edits made through repeated sub-messages."""
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])   # 16 bytes
    # Plus 3 bytes: the 1-byte length prefix (16 < 128) and a 2-byte tag.
    self.assertEqual(19, self.packed_extended_proto.ByteSize())
1699
1700
1701# TODO(robinson): We need cross-language serialization consistency tests.
1702# Issues to be sure to cover include:
1703#   * Handling of unrecognized tags ("uninterpreted_bytes").
1704#   * Handling of MessageSets.
1705#   * Consistent ordering of tags in the wire format,
1706#     including ordering between extensions and non-extension
1707#     fields.
1708#   * Consistent serialization of negative numbers, especially
1709#     negative int32s.
1710#   * Handling of empty submessages (with and without "has"
1711#     bits set).
1712
1713class SerializationTest(unittest.TestCase):
1714
1715  def testSerializeEmtpyMessage(self):
1716    first_proto = unittest_pb2.TestAllTypes()
1717    second_proto = unittest_pb2.TestAllTypes()
1718    serialized = first_proto.SerializeToString()
1719    self.assertEqual(first_proto.ByteSize(), len(serialized))
1720    second_proto.MergeFromString(serialized)
1721    self.assertEqual(first_proto, second_proto)
1722
1723  def testSerializeAllFields(self):
1724    first_proto = unittest_pb2.TestAllTypes()
1725    second_proto = unittest_pb2.TestAllTypes()
1726    test_util.SetAllFields(first_proto)
1727    serialized = first_proto.SerializeToString()
1728    self.assertEqual(first_proto.ByteSize(), len(serialized))
1729    second_proto.MergeFromString(serialized)
1730    self.assertEqual(first_proto, second_proto)
1731
1732  def testSerializeAllExtensions(self):
1733    first_proto = unittest_pb2.TestAllExtensions()
1734    second_proto = unittest_pb2.TestAllExtensions()
1735    test_util.SetAllExtensions(first_proto)
1736    serialized = first_proto.SerializeToString()
1737    second_proto.MergeFromString(serialized)
1738    self.assertEqual(first_proto, second_proto)
1739
1740  def testSerializeNegativeValues(self):
1741    first_proto = unittest_pb2.TestAllTypes()
1742
1743    first_proto.optional_int32 = -1
1744    first_proto.optional_int64 = -(2 << 40)
1745    first_proto.optional_sint32 = -3
1746    first_proto.optional_sint64 = -(4 << 40)
1747    first_proto.optional_sfixed32 = -5
1748    first_proto.optional_sfixed64 = -(6 << 40)
1749
1750    second_proto = unittest_pb2.TestAllTypes.FromString(
1751        first_proto.SerializeToString())
1752
1753    self.assertEqual(first_proto, second_proto)
1754
1755  def testParseTruncated(self):
1756    first_proto = unittest_pb2.TestAllTypes()
1757    test_util.SetAllFields(first_proto)
1758    serialized = first_proto.SerializeToString()
1759
1760    for truncation_point in xrange(len(serialized) + 1):
1761      try:
1762        second_proto = unittest_pb2.TestAllTypes()
1763        unknown_fields = unittest_pb2.TestEmptyMessage()
1764        pos = second_proto._InternalParse(serialized, 0, truncation_point)
1765        # If we didn't raise an error then we read exactly the amount expected.
1766        self.assertEqual(truncation_point, pos)
1767
1768        # Parsing to unknown fields should not throw if parsing to known fields
1769        # did not.
1770        try:
1771          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
1772          self.assertEqual(truncation_point, pos2)
1773        except message.DecodeError:
1774          self.fail('Parsing unknown fields failed when parsing known fields '
1775                    'did not.')
1776      except message.DecodeError:
1777        # Parsing unknown fields should also fail.
1778        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
1779                          serialized, 0, truncation_point)
1780
  def testCanonicalSerializationOrder(self):
    """Fields and extensions serialize in ascending field-number order.

    Values are assigned in descending field-number order, interleaving
    regular fields with extensions. The encoded bytes must still list them
    in ascending order.
    """
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    # Walk the bytes with the _MiniDecoder helper (defined elsewhere). Each
    # field must appear as a (field number, VARINT) tag followed by its
    # value, with field numbers 1..5 in ascending order.
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())
1805
1806  def testCanonicalSerializationOrderSameAsCpp(self):
1807    # Copy of the same test we use for C++.
1808    proto = unittest_pb2.TestFieldOrderings()
1809    test_util.SetAllFieldsAndExtensions(proto)
1810    serialized = proto.SerializeToString()
1811    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
1812
1813  def testMergeFromStringWhenFieldsAlreadySet(self):
1814    first_proto = unittest_pb2.TestAllTypes()
1815    first_proto.repeated_string.append('foobar')
1816    first_proto.optional_int32 = 23
1817    first_proto.optional_nested_message.bb = 42
1818    serialized = first_proto.SerializeToString()
1819
1820    second_proto = unittest_pb2.TestAllTypes()
1821    second_proto.repeated_string.append('baz')
1822    second_proto.optional_int32 = 100
1823    second_proto.optional_nested_message.bb = 999
1824
1825    second_proto.MergeFromString(serialized)
1826    # Ensure that we append to repeated fields.
1827    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
1828    # Ensure that we overwrite nonrepeatd scalars.
1829    self.assertEqual(23, second_proto.optional_int32)
1830    # Ensure that we recursively call MergeFromString() on
1831    # submessages.
1832    self.assertEqual(42, second_proto.optional_nested_message.bb)
1833
1834  def testMessageSetWireFormat(self):
1835    proto = unittest_mset_pb2.TestMessageSet()
1836    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
1837    extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
1838    extension1 = extension_message1.message_set_extension
1839    extension2 = extension_message2.message_set_extension
1840    proto.Extensions[extension1].i = 123
1841    proto.Extensions[extension2].str = 'foo'
1842
1843    # Serialize using the MessageSet wire format (this is specified in the
1844    # .proto file).
1845    serialized = proto.SerializeToString()
1846
1847    raw = unittest_mset_pb2.RawMessageSet()
1848    self.assertEqual(False,
1849                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
1850    raw.MergeFromString(serialized)
1851    self.assertEqual(2, len(raw.item))
1852
1853    message1 = unittest_mset_pb2.TestMessageSetExtension1()
1854    message1.MergeFromString(raw.item[0].message)
1855    self.assertEqual(123, message1.i)
1856
1857    message2 = unittest_mset_pb2.TestMessageSetExtension2()
1858    message2.MergeFromString(raw.item[1].message)
1859    self.assertEqual('foo', message2.str)
1860
1861    # Deserialize using the MessageSet wire format.
1862    proto2 = unittest_mset_pb2.TestMessageSet()
1863    proto2.MergeFromString(serialized)
1864    self.assertEqual(123, proto2.Extensions[extension1].i)
1865    self.assertEqual('foo', proto2.Extensions[extension2].str)
1866
1867    # Check byte size.
1868    self.assertEqual(proto2.ByteSize(), len(serialized))
1869    self.assertEqual(proto.ByteSize(), len(serialized))
1870
1871  def testMessageSetWireFormatUnknownExtension(self):
1872    # Create a message using the message set wire format with an unknown
1873    # message.
1874    raw = unittest_mset_pb2.RawMessageSet()
1875
1876    # Add an item.
1877    item = raw.item.add()
1878    item.type_id = 1545008
1879    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
1880    message1 = unittest_mset_pb2.TestMessageSetExtension1()
1881    message1.i = 12345
1882    item.message = message1.SerializeToString()
1883
1884    # Add a second, unknown extension.
1885    item = raw.item.add()
1886    item.type_id = 1545009
1887    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
1888    message1 = unittest_mset_pb2.TestMessageSetExtension1()
1889    message1.i = 12346
1890    item.message = message1.SerializeToString()
1891
1892    # Add another unknown extension.
1893    item = raw.item.add()
1894    item.type_id = 1545010
1895    message1 = unittest_mset_pb2.TestMessageSetExtension2()
1896    message1.str = 'foo'
1897    item.message = message1.SerializeToString()
1898
1899    serialized = raw.SerializeToString()
1900
1901    # Parse message using the message set wire format.
1902    proto = unittest_mset_pb2.TestMessageSet()
1903    proto.MergeFromString(serialized)
1904
1905    # Check that the message parsed well.
1906    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
1907    extension1 = extension_message1.message_set_extension
1908    self.assertEquals(12345, proto.Extensions[extension1].i)
1909
1910  def testUnknownFields(self):
1911    proto = unittest_pb2.TestAllTypes()
1912    test_util.SetAllFields(proto)
1913
1914    serialized = proto.SerializeToString()
1915
1916    # The empty message should be parsable with all of the fields
1917    # unknown.
1918    proto2 = unittest_pb2.TestEmptyMessage()
1919
1920    # Parsing this message should succeed.
1921    proto2.MergeFromString(serialized)
1922
1923    # Now test with a int64 field set.
1924    proto = unittest_pb2.TestAllTypes()
1925    proto.optional_int64 = 0x0fffffffffffffff
1926    serialized = proto.SerializeToString()
1927    # The empty message should be parsable with all of the fields
1928    # unknown.
1929    proto2 = unittest_pb2.TestEmptyMessage()
1930    # Parsing this message should succeed.
1931    proto2.MergeFromString(serialized)
1932
1933  def _CheckRaises(self, exc_class, callable_obj, exception):
1934    """This method checks if the excpetion type and message are as expected."""
1935    try:
1936      callable_obj()
1937    except exc_class, ex:
1938      # Check if the exception message is the right one.
1939      self.assertEqual(exception, str(ex))
1940      return
1941    else:
1942      raise self.failureException('%s not raised' % str(exc_class))
1943
1944  def testSerializeUninitialized(self):
1945    proto = unittest_pb2.TestRequired()
1946    self._CheckRaises(
1947        message.EncodeError,
1948        proto.SerializeToString,
1949        'Message is missing required fields: a,b,c')
1950    # Shouldn't raise exceptions.
1951    partial = proto.SerializePartialToString()
1952
1953    proto.a = 1
1954    self._CheckRaises(
1955        message.EncodeError,
1956        proto.SerializeToString,
1957        'Message is missing required fields: b,c')
1958    # Shouldn't raise exceptions.
1959    partial = proto.SerializePartialToString()
1960
1961    proto.b = 2
1962    self._CheckRaises(
1963        message.EncodeError,
1964        proto.SerializeToString,
1965        'Message is missing required fields: c')
1966    # Shouldn't raise exceptions.
1967    partial = proto.SerializePartialToString()
1968
1969    proto.c = 3
1970    serialized = proto.SerializeToString()
1971    # Shouldn't raise exceptions.
1972    partial = proto.SerializePartialToString()
1973
1974    proto2 = unittest_pb2.TestRequired()
1975    proto2.MergeFromString(serialized)
1976    self.assertEqual(1, proto2.a)
1977    self.assertEqual(2, proto2.b)
1978    self.assertEqual(3, proto2.c)
1979    proto2.ParseFromString(partial)
1980    self.assertEqual(1, proto2.a)
1981    self.assertEqual(2, proto2.b)
1982    self.assertEqual(3, proto2.c)
1983
1984  def testSerializeUninitializedSubMessage(self):
1985    proto = unittest_pb2.TestRequiredForeign()
1986
1987    # Sub-message doesn't exist yet, so this succeeds.
1988    proto.SerializeToString()
1989
1990    proto.optional_message.a = 1
1991    self._CheckRaises(
1992        message.EncodeError,
1993        proto.SerializeToString,
1994        'Message is missing required fields: '
1995        'optional_message.b,optional_message.c')
1996
1997    proto.optional_message.b = 2
1998    proto.optional_message.c = 3
1999    proto.SerializeToString()
2000
2001    proto.repeated_message.add().a = 1
2002    proto.repeated_message.add().b = 2
2003    self._CheckRaises(
2004        message.EncodeError,
2005        proto.SerializeToString,
2006        'Message is missing required fields: '
2007        'repeated_message[0].b,repeated_message[0].c,'
2008        'repeated_message[1].a,repeated_message[1].c')
2009
2010    proto.repeated_message[0].b = 2
2011    proto.repeated_message[0].c = 3
2012    proto.repeated_message[1].a = 1
2013    proto.repeated_message[1].c = 3
2014    proto.SerializeToString()
2015
2016  def testSerializeAllPackedFields(self):
2017    first_proto = unittest_pb2.TestPackedTypes()
2018    second_proto = unittest_pb2.TestPackedTypes()
2019    test_util.SetAllPackedFields(first_proto)
2020    serialized = first_proto.SerializeToString()
2021    self.assertEqual(first_proto.ByteSize(), len(serialized))
2022    bytes_read = second_proto.MergeFromString(serialized)
2023    self.assertEqual(second_proto.ByteSize(), bytes_read)
2024    self.assertEqual(first_proto, second_proto)
2025
2026  def testSerializeAllPackedExtensions(self):
2027    first_proto = unittest_pb2.TestPackedExtensions()
2028    second_proto = unittest_pb2.TestPackedExtensions()
2029    test_util.SetAllPackedExtensions(first_proto)
2030    serialized = first_proto.SerializeToString()
2031    bytes_read = second_proto.MergeFromString(serialized)
2032    self.assertEqual(second_proto.ByteSize(), bytes_read)
2033    self.assertEqual(first_proto, second_proto)
2034
2035  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
2036    first_proto = unittest_pb2.TestPackedTypes()
2037    first_proto.packed_int32.extend([1, 2])
2038    first_proto.packed_double.append(3.0)
2039    serialized = first_proto.SerializeToString()
2040
2041    second_proto = unittest_pb2.TestPackedTypes()
2042    second_proto.packed_int32.append(3)
2043    second_proto.packed_double.extend([1.0, 2.0])
2044    second_proto.packed_sint32.append(4)
2045
2046    second_proto.MergeFromString(serialized)
2047    self.assertEqual([3, 1, 2], second_proto.packed_int32)
2048    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
2049    self.assertEqual([4], second_proto.packed_sint32)
2050
  def testPackedFieldsWireFormat(self):
    """Packed repeated fields encode as one length-delimited record each.

    Decodes the output by hand: each packed field must be a
    LENGTH_DELIMITED tag, a varint byte-length, and then the packed values,
    with the fields in ascending field-number order (90, 100, 101).
    """
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    # packed_int32: 5-byte payload of four varints.
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    # Field 100 (read as float): 4-byte payload.
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    # Field 101 (read as double): 16-byte payload of two doubles.
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    # Nothing may follow the three packed records.
    self.assertTrue(d.EndOfStream())
2074
2075  def testParsePackedFromUnpacked(self):
2076    unpacked = unittest_pb2.TestUnpackedTypes()
2077    test_util.SetAllUnpackedFields(unpacked)
2078    packed = unittest_pb2.TestPackedTypes()
2079    packed.MergeFromString(unpacked.SerializeToString())
2080    expected = unittest_pb2.TestPackedTypes()
2081    test_util.SetAllPackedFields(expected)
2082    self.assertEqual(expected, packed)
2083
2084  def testParseUnpackedFromPacked(self):
2085    packed = unittest_pb2.TestPackedTypes()
2086    test_util.SetAllPackedFields(packed)
2087    unpacked = unittest_pb2.TestUnpackedTypes()
2088    unpacked.MergeFromString(packed.SerializeToString())
2089    expected = unittest_pb2.TestUnpackedTypes()
2090    test_util.SetAllUnpackedFields(expected)
2091    self.assertEqual(expected, unpacked)
2092
2093  def testFieldNumbers(self):
2094    proto = unittest_pb2.TestAllTypes()
2095    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
2096    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
2097    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
2098    self.assertEqual(
2099      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
2100    self.assertEqual(
2101      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
2102    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
2103    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
2104    self.assertEqual(
2105      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
2106    self.assertEqual(
2107      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
2108
2109  def testExtensionFieldNumbers(self):
2110    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
2111    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
2112    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
2113    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
2114    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
2115    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
2116    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
2117    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
2118    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
2119    self.assertEqual(
2120      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
2121    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
2122    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
2123      21)
2124    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
2125    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
2126    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
2127    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
2128    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
2129    self.assertEqual(
2130      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
2131    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
2132    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
2133      51)
2134
2135  def testInitKwargs(self):
2136    proto = unittest_pb2.TestAllTypes(
2137        optional_int32=1,
2138        optional_string='foo',
2139        optional_bool=True,
2140        optional_bytes='bar',
2141        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
2142        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
2143        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
2144        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
2145        repeated_int32=[1, 2, 3])
2146    self.assertTrue(proto.IsInitialized())
2147    self.assertTrue(proto.HasField('optional_int32'))
2148    self.assertTrue(proto.HasField('optional_string'))
2149    self.assertTrue(proto.HasField('optional_bool'))
2150    self.assertTrue(proto.HasField('optional_bytes'))
2151    self.assertTrue(proto.HasField('optional_nested_message'))
2152    self.assertTrue(proto.HasField('optional_foreign_message'))
2153    self.assertTrue(proto.HasField('optional_nested_enum'))
2154    self.assertTrue(proto.HasField('optional_foreign_enum'))
2155    self.assertEqual(1, proto.optional_int32)
2156    self.assertEqual('foo', proto.optional_string)
2157    self.assertEqual(True, proto.optional_bool)
2158    self.assertEqual('bar', proto.optional_bytes)
2159    self.assertEqual(1, proto.optional_nested_message.bb)
2160    self.assertEqual(1, proto.optional_foreign_message.c)
2161    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
2162                     proto.optional_nested_enum)
2163    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
2164    self.assertEqual([1, 2, 3], proto.repeated_int32)
2165
2166  def testInitArgsUnknownFieldName(self):
2167    def InitalizeEmptyMessageWithExtraKeywordArg():
2168      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
2169    self._CheckRaises(ValueError,
2170                      InitalizeEmptyMessageWithExtraKeywordArg,
2171                      'Protocol message has no "unknown" field.')
2172
2173  def testInitRequiredKwargs(self):
2174    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
2175    self.assertTrue(proto.IsInitialized())
2176    self.assertTrue(proto.HasField('a'))
2177    self.assertTrue(proto.HasField('b'))
2178    self.assertTrue(proto.HasField('c'))
2179    self.assertTrue(not proto.HasField('dummy2'))
2180    self.assertEqual(1, proto.a)
2181    self.assertEqual(1, proto.b)
2182    self.assertEqual(1, proto.c)
2183
2184  def testInitRequiredForeignKwargs(self):
2185    proto = unittest_pb2.TestRequiredForeign(
2186        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
2187    self.assertTrue(proto.IsInitialized())
2188    self.assertTrue(proto.HasField('optional_message'))
2189    self.assertTrue(proto.optional_message.IsInitialized())
2190    self.assertTrue(proto.optional_message.HasField('a'))
2191    self.assertTrue(proto.optional_message.HasField('b'))
2192    self.assertTrue(proto.optional_message.HasField('c'))
2193    self.assertTrue(not proto.optional_message.HasField('dummy2'))
2194    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
2195                     proto.optional_message)
2196    self.assertEqual(1, proto.optional_message.a)
2197    self.assertEqual(1, proto.optional_message.b)
2198    self.assertEqual(1, proto.optional_message.c)
2199
2200  def testInitRepeatedKwargs(self):
2201    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
2202    self.assertTrue(proto.IsInitialized())
2203    self.assertEqual(1, proto.repeated_int32[0])
2204    self.assertEqual(2, proto.repeated_int32[1])
2205    self.assertEqual(3, proto.repeated_int32[2])
2206
2207
class OptionsTest(unittest.TestCase):
  """Checks message and field options as reported by descriptor GetOptions()."""

  def testMessageOptions(self):
    """message_set_wire_format is True for TestMessageSet and False for
    TestAllTypes."""
    proto = unittest_mset_pb2.TestMessageSet()
    self.assertEqual(True,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(False,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """Fields set on TestAllTypes report packed=False; fields set on
    TestPackedTypes report packed=True and have the repeated label.

    Each half first asserts how many fields ListFields() reported, so the
    per-field loops cannot pass vacuously on an empty result.
    """
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int32 = 1
    proto.optional_double = 3.0
    set_fields = proto.ListFields()
    # Two fields were set above.
    self.assertEqual(2, len(set_fields))
    for field_descriptor, _ in set_fields:
      self.assertEqual(False, field_descriptor.GetOptions().packed)

    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.append(1)
    proto.packed_double.append(3.0)
    set_fields = proto.ListFields()
    # Two repeated fields were made non-empty above.
    self.assertEqual(2, len(set_fields))
    for field_descriptor, _ in set_fields:
      self.assertEqual(True, field_descriptor.GetOptions().packed)
      self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
                       field_descriptor.label)
2232
2233
2234
if __name__ == '__main__':
  # When executed as a script, run the TestCase classes defined in this module.
  unittest.main()
2237