1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# http://code.google.com/p/protobuf/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38__author__ = 'robinson@google.com (Will Robinson)'
39
40import gc
41import operator
42import struct
43
44import unittest
45from google.protobuf import unittest_import_pb2
46from google.protobuf import unittest_mset_pb2
47from google.protobuf import unittest_pb2
48from google.protobuf import descriptor_pb2
49from google.protobuf import descriptor
50from google.protobuf import message
51from google.protobuf import reflection
52from google.protobuf.internal import api_implementation
53from google.protobuf.internal import more_extensions_pb2
54from google.protobuf.internal import more_messages_pb2
55from google.protobuf.internal import wire_format
56from google.protobuf.internal import test_util
57from google.protobuf.internal import decoder
58
59
60class _MiniDecoder(object):
61  """Decodes a stream of values from a string.
62
63  Once upon a time we actually had a class called decoder.Decoder.  Then we
64  got rid of it during a redesign that made decoding much, much faster overall.
65  But a couple tests in this file used it to check that the serialized form of
66  a message was correct.  So, this class implements just the methods that were
67  used by said tests, so that we don't have to rewrite the tests.
68  """
69
70  def __init__(self, bytes):
71    self._bytes = bytes
72    self._pos = 0
73
74  def ReadVarint(self):
75    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
76    return result
77
78  ReadInt32 = ReadVarint
79  ReadInt64 = ReadVarint
80  ReadUInt32 = ReadVarint
81  ReadUInt64 = ReadVarint
82
83  def ReadSInt64(self):
84    return wire_format.ZigZagDecode(self.ReadVarint())
85
86  ReadSInt32 = ReadSInt64
87
88  def ReadFieldNumberAndWireType(self):
89    return wire_format.UnpackTag(self.ReadVarint())
90
91  def ReadFloat(self):
92    result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
93    self._pos += 4
94    return result
95
96  def ReadDouble(self):
97    result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
98    self._pos += 8
99    return result
100
101  def EndOfStream(self):
102    return self._pos == len(self._bytes)
103
104
105class ReflectionTest(unittest.TestCase):
106
107  def assertListsEqual(self, values, others):
108    self.assertEqual(len(values), len(others))
109    for i in range(len(values)):
110      self.assertEqual(values[i], others[i])
111
112  def testScalarConstructor(self):
113    # Constructor with only scalar types should succeed.
114    proto = unittest_pb2.TestAllTypes(
115        optional_int32=24,
116        optional_double=54.321,
117        optional_string='optional_string')
118
119    self.assertEqual(24, proto.optional_int32)
120    self.assertEqual(54.321, proto.optional_double)
121    self.assertEqual('optional_string', proto.optional_string)
122
123  def testRepeatedScalarConstructor(self):
124    # Constructor with only repeated scalar types should succeed.
125    proto = unittest_pb2.TestAllTypes(
126        repeated_int32=[1, 2, 3, 4],
127        repeated_double=[1.23, 54.321],
128        repeated_bool=[True, False, False],
129        repeated_string=["optional_string"])
130
131    self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
132    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
133    self.assertEquals([True, False, False], list(proto.repeated_bool))
134    self.assertEquals(["optional_string"], list(proto.repeated_string))
135
136  def testRepeatedCompositeConstructor(self):
137    # Constructor with only repeated composite types should succeed.
138    proto = unittest_pb2.TestAllTypes(
139        repeated_nested_message=[
140            unittest_pb2.TestAllTypes.NestedMessage(
141                bb=unittest_pb2.TestAllTypes.FOO),
142            unittest_pb2.TestAllTypes.NestedMessage(
143                bb=unittest_pb2.TestAllTypes.BAR)],
144        repeated_foreign_message=[
145            unittest_pb2.ForeignMessage(c=-43),
146            unittest_pb2.ForeignMessage(c=45324),
147            unittest_pb2.ForeignMessage(c=12)],
148        repeatedgroup=[
149            unittest_pb2.TestAllTypes.RepeatedGroup(),
150            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
151            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
152
153    self.assertEquals(
154        [unittest_pb2.TestAllTypes.NestedMessage(
155            bb=unittest_pb2.TestAllTypes.FOO),
156         unittest_pb2.TestAllTypes.NestedMessage(
157             bb=unittest_pb2.TestAllTypes.BAR)],
158        list(proto.repeated_nested_message))
159    self.assertEquals(
160        [unittest_pb2.ForeignMessage(c=-43),
161         unittest_pb2.ForeignMessage(c=45324),
162         unittest_pb2.ForeignMessage(c=12)],
163        list(proto.repeated_foreign_message))
164    self.assertEquals(
165        [unittest_pb2.TestAllTypes.RepeatedGroup(),
166         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
167         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
168        list(proto.repeatedgroup))
169
170  def testMixedConstructor(self):
171    # Constructor with only mixed types should succeed.
172    proto = unittest_pb2.TestAllTypes(
173        optional_int32=24,
174        optional_string='optional_string',
175        repeated_double=[1.23, 54.321],
176        repeated_bool=[True, False, False],
177        repeated_nested_message=[
178            unittest_pb2.TestAllTypes.NestedMessage(
179                bb=unittest_pb2.TestAllTypes.FOO),
180            unittest_pb2.TestAllTypes.NestedMessage(
181                bb=unittest_pb2.TestAllTypes.BAR)],
182        repeated_foreign_message=[
183            unittest_pb2.ForeignMessage(c=-43),
184            unittest_pb2.ForeignMessage(c=45324),
185            unittest_pb2.ForeignMessage(c=12)])
186
187    self.assertEqual(24, proto.optional_int32)
188    self.assertEqual('optional_string', proto.optional_string)
189    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
190    self.assertEquals([True, False, False], list(proto.repeated_bool))
191    self.assertEquals(
192        [unittest_pb2.TestAllTypes.NestedMessage(
193            bb=unittest_pb2.TestAllTypes.FOO),
194         unittest_pb2.TestAllTypes.NestedMessage(
195             bb=unittest_pb2.TestAllTypes.BAR)],
196        list(proto.repeated_nested_message))
197    self.assertEquals(
198        [unittest_pb2.ForeignMessage(c=-43),
199         unittest_pb2.ForeignMessage(c=45324),
200         unittest_pb2.ForeignMessage(c=12)],
201        list(proto.repeated_foreign_message))
202
203  def testConstructorTypeError(self):
204    self.assertRaises(
205        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
206    self.assertRaises(
207        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
208    self.assertRaises(
209        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
210    self.assertRaises(
211        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
212    self.assertRaises(
213        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
214    self.assertRaises(
215        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
216    self.assertRaises(
217        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
218    self.assertRaises(
219        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
220    self.assertRaises(
221        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
222
223  def testConstructorInvalidatesCachedByteSize(self):
224    message = unittest_pb2.TestAllTypes(optional_int32 = 12)
225    self.assertEquals(2, message.ByteSize())
226
227    message = unittest_pb2.TestAllTypes(
228        optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
229    self.assertEquals(3, message.ByteSize())
230
231    message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
232    self.assertEquals(3, message.ByteSize())
233
234    message = unittest_pb2.TestAllTypes(
235        repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
236    self.assertEquals(3, message.ByteSize())
237
238  def testSimpleHasBits(self):
239    # Test a scalar.
240    proto = unittest_pb2.TestAllTypes()
241    self.assertTrue(not proto.HasField('optional_int32'))
242    self.assertEqual(0, proto.optional_int32)
243    # HasField() shouldn't be true if all we've done is
244    # read the default value.
245    self.assertTrue(not proto.HasField('optional_int32'))
246    proto.optional_int32 = 1
247    # Setting a value however *should* set the "has" bit.
248    self.assertTrue(proto.HasField('optional_int32'))
249    proto.ClearField('optional_int32')
250    # And clearing that value should unset the "has" bit.
251    self.assertTrue(not proto.HasField('optional_int32'))
252
253  def testHasBitsWithSinglyNestedScalar(self):
254    # Helper used to test foreign messages and groups.
255    #
256    # composite_field_name should be the name of a non-repeated
257    # composite (i.e., foreign or group) field in TestAllTypes,
258    # and scalar_field_name should be the name of an integer-valued
259    # scalar field within that composite.
260    #
261    # I never thought I'd miss C++ macros and templates so much. :(
262    # This helper is semantically just:
263    #
264    #   assert proto.composite_field.scalar_field == 0
265    #   assert not proto.composite_field.HasField('scalar_field')
266    #   assert not proto.HasField('composite_field')
267    #
268    #   proto.composite_field.scalar_field = 10
269    #   old_composite_field = proto.composite_field
270    #
271    #   assert proto.composite_field.scalar_field == 10
272    #   assert proto.composite_field.HasField('scalar_field')
273    #   assert proto.HasField('composite_field')
274    #
275    #   proto.ClearField('composite_field')
276    #
277    #   assert not proto.composite_field.HasField('scalar_field')
278    #   assert not proto.HasField('composite_field')
279    #   assert proto.composite_field.scalar_field == 0
280    #
281    #   # Now ensure that ClearField('composite_field') disconnected
282    #   # the old field object from the object tree...
283    #   assert old_composite_field is not proto.composite_field
284    #   old_composite_field.scalar_field = 20
285    #   assert not proto.composite_field.HasField('scalar_field')
286    #   assert not proto.HasField('composite_field')
287    def TestCompositeHasBits(composite_field_name, scalar_field_name):
288      proto = unittest_pb2.TestAllTypes()
289      # First, check that we can get the scalar value, and see that it's the
290      # default (0), but that proto.HasField('omposite') and
291      # proto.composite.HasField('scalar') will still return False.
292      composite_field = getattr(proto, composite_field_name)
293      original_scalar_value = getattr(composite_field, scalar_field_name)
294      self.assertEqual(0, original_scalar_value)
295      # Assert that the composite object does not "have" the scalar.
296      self.assertTrue(not composite_field.HasField(scalar_field_name))
297      # Assert that proto does not "have" the composite field.
298      self.assertTrue(not proto.HasField(composite_field_name))
299
300      # Now set the scalar within the composite field.  Ensure that the setting
301      # is reflected, and that proto.HasField('composite') and
302      # proto.composite.HasField('scalar') now both return True.
303      new_val = 20
304      setattr(composite_field, scalar_field_name, new_val)
305      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
306      # Hold on to a reference to the current composite_field object.
307      old_composite_field = composite_field
308      # Assert that the has methods now return true.
309      self.assertTrue(composite_field.HasField(scalar_field_name))
310      self.assertTrue(proto.HasField(composite_field_name))
311
312      # Now call the clear method...
313      proto.ClearField(composite_field_name)
314
315      # ...and ensure that the "has" bits are all back to False...
316      composite_field = getattr(proto, composite_field_name)
317      self.assertTrue(not composite_field.HasField(scalar_field_name))
318      self.assertTrue(not proto.HasField(composite_field_name))
319      # ...and ensure that the scalar field has returned to its default.
320      self.assertEqual(0, getattr(composite_field, scalar_field_name))
321
322      self.assertTrue(old_composite_field is not composite_field)
323      setattr(old_composite_field, scalar_field_name, new_val)
324      self.assertTrue(not composite_field.HasField(scalar_field_name))
325      self.assertTrue(not proto.HasField(composite_field_name))
326      self.assertEqual(0, getattr(composite_field, scalar_field_name))
327
328    # Test simple, single-level nesting when we set a scalar.
329    TestCompositeHasBits('optionalgroup', 'a')
330    TestCompositeHasBits('optional_nested_message', 'bb')
331    TestCompositeHasBits('optional_foreign_message', 'c')
332    TestCompositeHasBits('optional_import_message', 'd')
333
  def testReferencesToNestedMessage(self):
    # Regression test: a submessage fetched from its parent must stay
    # writable after the caller drops its only reference to the parent.
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    # Drop the parent; `nested` is now the only thing we still hold.
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23
341
342  def testDisconnectingNestedMessageBeforeSettingField(self):
343    proto = unittest_pb2.TestAllTypes()
344    nested = proto.optional_nested_message
345    proto.ClearField('optional_nested_message')  # Should disconnect from parent
346    self.assertTrue(nested is not proto.optional_nested_message)
347    nested.bb = 23
348    self.assertTrue(not proto.HasField('optional_nested_message'))
349    self.assertEqual(0, proto.optional_nested_message.bb)
350
  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
    # Detaching a never-written submessage and then freeing both the parent
    # and the detached submessage must not break later access to the same
    # submessage field on a brand-new message.
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')
    del proto
    del nested
    # Force a garbage collect so that the underlying CMessages are freed along
    # with the Messages they point to. This is to make sure we're not deleting
    # default message instances.
    gc.collect()
    # Fetching the submessage again is the actual check: it must not touch
    # anything freed above.  The result is intentionally unused.
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
363
364  def testDisconnectingNestedMessageAfterSettingField(self):
365    proto = unittest_pb2.TestAllTypes()
366    nested = proto.optional_nested_message
367    nested.bb = 5
368    self.assertTrue(proto.HasField('optional_nested_message'))
369    proto.ClearField('optional_nested_message')  # Should disconnect from parent
370    self.assertEqual(5, nested.bb)
371    self.assertEqual(0, proto.optional_nested_message.bb)
372    self.assertTrue(nested is not proto.optional_nested_message)
373    nested.bb = 23
374    self.assertTrue(not proto.HasField('optional_nested_message'))
375    self.assertEqual(0, proto.optional_nested_message.bb)
376
377  def testDisconnectingNestedMessageBeforeGettingField(self):
378    proto = unittest_pb2.TestAllTypes()
379    self.assertTrue(not proto.HasField('optional_nested_message'))
380    proto.ClearField('optional_nested_message')
381    self.assertTrue(not proto.HasField('optional_nested_message'))
382
383  def testDisconnectingNestedMessageAfterMerge(self):
384    # This test exercises the code path that does not use ReleaseMessage().
385    # The underlying fear is that if we use ReleaseMessage() incorrectly,
386    # we will have memory leaks.  It's hard to check that that doesn't happen,
387    # but at least we can exercise that code path to make sure it works.
388    proto1 = unittest_pb2.TestAllTypes()
389    proto2 = unittest_pb2.TestAllTypes()
390    proto2.optional_nested_message.bb = 5
391    proto1.MergeFrom(proto2)
392    self.assertTrue(proto1.HasField('optional_nested_message'))
393    proto1.ClearField('optional_nested_message')
394    self.assertTrue(not proto1.HasField('optional_nested_message'))
395
  def testDisconnectingLazyNestedMessage(self):
    # This test exercises releasing a nested message that is lazy. This test
    # only exercises real code in the C++ implementation as Python does not
    # support lazy parsing, but the current C++ implementation results in
    # memory corruption and a crash.
    # Because of that crash, the test returns early (passing vacuously) for
    # any implementation other than pure Python.
    if api_implementation.Type() != 'python':
      return
    proto = unittest_pb2.TestAllTypes()
    proto.optional_lazy_message.bb = 5
    proto.ClearField('optional_lazy_message')
    # Free the parent and force collection so any use-after-free surfaces here.
    del proto
    gc.collect()
408
409  def testHasBitsWhenModifyingRepeatedFields(self):
410    # Test nesting when we add an element to a repeated field in a submessage.
411    proto = unittest_pb2.TestNestedMessageHasBits()
412    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
413    self.assertEqual(
414        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
415    self.assertTrue(proto.HasField('optional_nested_message'))
416
417    # Do the same test, but with a repeated composite field within the
418    # submessage.
419    proto.ClearField('optional_nested_message')
420    self.assertTrue(not proto.HasField('optional_nested_message'))
421    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
422    self.assertTrue(proto.HasField('optional_nested_message'))
423
424  def testHasBitsForManyLevelsOfNesting(self):
425    # Test nesting many levels deep.
426    recursive_proto = unittest_pb2.TestMutualRecursionA()
427    self.assertTrue(not recursive_proto.HasField('bb'))
428    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
429    self.assertTrue(not recursive_proto.HasField('bb'))
430    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
431    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
432    self.assertTrue(recursive_proto.HasField('bb'))
433    self.assertTrue(recursive_proto.bb.HasField('a'))
434    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
435    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
436    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
437    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
438    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
439
440  def testSingularListFields(self):
441    proto = unittest_pb2.TestAllTypes()
442    proto.optional_fixed32 = 1
443    proto.optional_int32 = 5
444    proto.optional_string = 'foo'
445    # Access sub-message but don't set it yet.
446    nested_message = proto.optional_nested_message
447    self.assertEqual(
448      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
449        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
450        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
451      proto.ListFields())
452
453    proto.optional_nested_message.bb = 123
454    self.assertEqual(
455      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
456        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
457        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
458        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
459             nested_message) ],
460      proto.ListFields())
461
462  def testRepeatedListFields(self):
463    proto = unittest_pb2.TestAllTypes()
464    proto.repeated_fixed32.append(1)
465    proto.repeated_int32.append(5)
466    proto.repeated_int32.append(11)
467    proto.repeated_string.extend(['foo', 'bar'])
468    proto.repeated_string.extend([])
469    proto.repeated_string.append('baz')
470    proto.repeated_string.extend(str(x) for x in xrange(2))
471    proto.optional_int32 = 21
472    proto.repeated_bool  # Access but don't set anything; should not be listed.
473    self.assertEqual(
474      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
475        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
476        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
477        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
478          ['foo', 'bar', 'baz', '0', '1']) ],
479      proto.ListFields())
480
481  def testSingularListExtensions(self):
482    proto = unittest_pb2.TestAllExtensions()
483    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
484    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
485    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
486    self.assertEqual(
487      [ (unittest_pb2.optional_int32_extension  , 5),
488        (unittest_pb2.optional_fixed32_extension, 1),
489        (unittest_pb2.optional_string_extension , 'foo') ],
490      proto.ListFields())
491
492  def testRepeatedListExtensions(self):
493    proto = unittest_pb2.TestAllExtensions()
494    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
495    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
496    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
497    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
498    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
499    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
500    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
501    self.assertEqual(
502      [ (unittest_pb2.optional_int32_extension  , 21),
503        (unittest_pb2.repeated_int32_extension  , [5, 11]),
504        (unittest_pb2.repeated_fixed32_extension, [1]),
505        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
506      proto.ListFields())
507
508  def testListFieldsAndExtensions(self):
509    proto = unittest_pb2.TestFieldOrderings()
510    test_util.SetAllFieldsAndExtensions(proto)
511    unittest_pb2.my_extension_int
512    self.assertEqual(
513      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
514        (unittest_pb2.my_extension_int               , 23),
515        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
516        (unittest_pb2.my_extension_string            , 'bar'),
517        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
518      proto.ListFields())
519
520  def testDefaultValues(self):
521    proto = unittest_pb2.TestAllTypes()
522    self.assertEqual(0, proto.optional_int32)
523    self.assertEqual(0, proto.optional_int64)
524    self.assertEqual(0, proto.optional_uint32)
525    self.assertEqual(0, proto.optional_uint64)
526    self.assertEqual(0, proto.optional_sint32)
527    self.assertEqual(0, proto.optional_sint64)
528    self.assertEqual(0, proto.optional_fixed32)
529    self.assertEqual(0, proto.optional_fixed64)
530    self.assertEqual(0, proto.optional_sfixed32)
531    self.assertEqual(0, proto.optional_sfixed64)
532    self.assertEqual(0.0, proto.optional_float)
533    self.assertEqual(0.0, proto.optional_double)
534    self.assertEqual(False, proto.optional_bool)
535    self.assertEqual('', proto.optional_string)
536    self.assertEqual('', proto.optional_bytes)
537
538    self.assertEqual(41, proto.default_int32)
539    self.assertEqual(42, proto.default_int64)
540    self.assertEqual(43, proto.default_uint32)
541    self.assertEqual(44, proto.default_uint64)
542    self.assertEqual(-45, proto.default_sint32)
543    self.assertEqual(46, proto.default_sint64)
544    self.assertEqual(47, proto.default_fixed32)
545    self.assertEqual(48, proto.default_fixed64)
546    self.assertEqual(49, proto.default_sfixed32)
547    self.assertEqual(-50, proto.default_sfixed64)
548    self.assertEqual(51.5, proto.default_float)
549    self.assertEqual(52e3, proto.default_double)
550    self.assertEqual(True, proto.default_bool)
551    self.assertEqual('hello', proto.default_string)
552    self.assertEqual('world', proto.default_bytes)
553    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
554    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
555    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
556                     proto.default_import_enum)
557
558    proto = unittest_pb2.TestExtremeDefaultValues()
559    self.assertEqual(u'\u1234', proto.utf8_string)
560
561  def testHasFieldWithUnknownFieldName(self):
562    proto = unittest_pb2.TestAllTypes()
563    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
564
565  def testClearFieldWithUnknownFieldName(self):
566    proto = unittest_pb2.TestAllTypes()
567    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
568
569  def testDisallowedAssignments(self):
570    # It's illegal to assign values directly to repeated fields
571    # or to nonrepeated composite fields.  Ensure that this fails.
572    proto = unittest_pb2.TestAllTypes()
573    # Repeated fields.
574    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
575    # Lists shouldn't work, either.
576    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
577    # Composite fields.
578    self.assertRaises(AttributeError, setattr, proto,
579                      'optional_nested_message', 23)
580    # Assignment to a repeated nested message field without specifying
581    # the index in the array of nested messages.
582    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
583                      'bb', 34)
584    # Assignment to an attribute of a repeated field.
585    self.assertRaises(AttributeError, setattr, proto.repeated_float,
586                      'some_attribute', 34)
587    # proto.nonexistent_field = 23 should fail as well.
588    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
589
590  def testSingleScalarTypeSafety(self):
591    proto = unittest_pb2.TestAllTypes()
592    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
593    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
594    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
595    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
596
597  def testSingleScalarBoundsChecking(self):
598    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
599      pb = unittest_pb2.TestAllTypes()
600      setattr(pb, field_name, expected_min)
601      self.assertEqual(expected_min, getattr(pb, field_name))
602      setattr(pb, field_name, expected_max)
603      self.assertEqual(expected_max, getattr(pb, field_name))
604      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
605      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
606
607    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
608    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
609    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
610    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
611
612    pb = unittest_pb2.TestAllTypes()
613    pb.optional_nested_enum = 1
614    self.assertEqual(1, pb.optional_nested_enum)
615
616    # Invalid enum values.
617    pb.optional_nested_enum = 0
618    self.assertEqual(0, pb.optional_nested_enum)
619
620    bytes_size_before = pb.ByteSize()
621
622    pb.optional_nested_enum = 4
623    self.assertEqual(4, pb.optional_nested_enum)
624
625    pb.optional_nested_enum = 0
626    self.assertEqual(0, pb.optional_nested_enum)
627
628    # Make sure that setting the same enum field doesn't just add unknown
629    # fields (but overwrites them).
630    self.assertEqual(bytes_size_before, pb.ByteSize())
631
632    # Is the invalid value preserved after serialization?
633    serialized = pb.SerializeToString()
634    pb2 = unittest_pb2.TestAllTypes()
635    pb2.ParseFromString(serialized)
636    self.assertEqual(0, pb2.optional_nested_enum)
637    self.assertEqual(pb, pb2)
638
639  def testRepeatedScalarTypeSafety(self):
640    proto = unittest_pb2.TestAllTypes()
641    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
642    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
643    self.assertRaises(TypeError, proto.repeated_string, 10)
644    self.assertRaises(TypeError, proto.repeated_bytes, 10)
645
646    proto.repeated_int32.append(10)
647    proto.repeated_int32[0] = 23
648    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
649    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
650
651    # Repeated enums tests.
652    #proto.repeated_nested_enum.append(0)
653
654  def testSingleScalarGettersAndSetters(self):
655    proto = unittest_pb2.TestAllTypes()
656    self.assertEqual(0, proto.optional_int32)
657    proto.optional_int32 = 1
658    self.assertEqual(1, proto.optional_int32)
659
660    proto.optional_uint64 = 0xffffffffffff
661    self.assertEqual(0xffffffffffff, proto.optional_uint64)
662    proto.optional_uint64 = 0xffffffffffffffff
663    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
664    # TODO(robinson): Test all other scalar field types.
665
666  def testSingleScalarClearField(self):
667    proto = unittest_pb2.TestAllTypes()
668    # Should be allowed to clear something that's not there (a no-op).
669    proto.ClearField('optional_int32')
670    proto.optional_int32 = 1
671    self.assertTrue(proto.HasField('optional_int32'))
672    proto.ClearField('optional_int32')
673    self.assertEqual(0, proto.optional_int32)
674    self.assertTrue(not proto.HasField('optional_int32'))
675    # TODO(robinson): Test all other scalar field types.
676
677  def testEnums(self):
678    proto = unittest_pb2.TestAllTypes()
679    self.assertEqual(1, proto.FOO)
680    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
681    self.assertEqual(2, proto.BAR)
682    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
683    self.assertEqual(3, proto.BAZ)
684    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
685
686  def testEnum_Name(self):
687    self.assertEqual('FOREIGN_FOO',
688                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
689    self.assertEqual('FOREIGN_BAR',
690                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
691    self.assertEqual('FOREIGN_BAZ',
692                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
693    self.assertRaises(ValueError,
694                      unittest_pb2.ForeignEnum.Name, 11312)
695
696    proto = unittest_pb2.TestAllTypes()
697    self.assertEqual('FOO',
698                     proto.NestedEnum.Name(proto.FOO))
699    self.assertEqual('FOO',
700                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
701    self.assertEqual('BAR',
702                     proto.NestedEnum.Name(proto.BAR))
703    self.assertEqual('BAR',
704                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
705    self.assertEqual('BAZ',
706                     proto.NestedEnum.Name(proto.BAZ))
707    self.assertEqual('BAZ',
708                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
709    self.assertRaises(ValueError,
710                      proto.NestedEnum.Name, 11312)
711    self.assertRaises(ValueError,
712                      unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
713
714  def testEnum_Value(self):
715    self.assertEqual(unittest_pb2.FOREIGN_FOO,
716                     unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
717    self.assertEqual(unittest_pb2.FOREIGN_BAR,
718                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
719    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
720                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
721    self.assertRaises(ValueError,
722                      unittest_pb2.ForeignEnum.Value, 'FO')
723
724    proto = unittest_pb2.TestAllTypes()
725    self.assertEqual(proto.FOO,
726                     proto.NestedEnum.Value('FOO'))
727    self.assertEqual(proto.FOO,
728                     unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
729    self.assertEqual(proto.BAR,
730                     proto.NestedEnum.Value('BAR'))
731    self.assertEqual(proto.BAR,
732                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
733    self.assertEqual(proto.BAZ,
734                     proto.NestedEnum.Value('BAZ'))
735    self.assertEqual(proto.BAZ,
736                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
737    self.assertRaises(ValueError,
738                      proto.NestedEnum.Value, 'Foo')
739    self.assertRaises(ValueError,
740                      unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
741
742  def testEnum_KeysAndValues(self):
743    self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
744                     unittest_pb2.ForeignEnum.keys())
745    self.assertEqual([4, 5, 6],
746                     unittest_pb2.ForeignEnum.values())
747    self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
748                      ('FOREIGN_BAZ', 6)],
749                     unittest_pb2.ForeignEnum.items())
750
751    proto = unittest_pb2.TestAllTypes()
752    self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys())
753    self.assertEqual([1, 2, 3], proto.NestedEnum.values())
754    self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)],
755                     proto.NestedEnum.items())
756
757  def testRepeatedScalars(self):
758    proto = unittest_pb2.TestAllTypes()
759
760    self.assertTrue(not proto.repeated_int32)
761    self.assertEqual(0, len(proto.repeated_int32))
762    proto.repeated_int32.append(5)
763    proto.repeated_int32.append(10)
764    proto.repeated_int32.append(15)
765    self.assertTrue(proto.repeated_int32)
766    self.assertEqual(3, len(proto.repeated_int32))
767
768    self.assertEqual([5, 10, 15], proto.repeated_int32)
769
770    # Test single retrieval.
771    self.assertEqual(5, proto.repeated_int32[0])
772    self.assertEqual(15, proto.repeated_int32[-1])
773    # Test out-of-bounds indices.
774    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
775    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
776    # Test incorrect types passed to __getitem__.
777    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
778    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
779
780    # Test single assignment.
781    proto.repeated_int32[1] = 20
782    self.assertEqual([5, 20, 15], proto.repeated_int32)
783
784    # Test insertion.
785    proto.repeated_int32.insert(1, 25)
786    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
787
788    # Test slice retrieval.
789    proto.repeated_int32.append(30)
790    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
791    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
792
793    # Test slice assignment with an iterator
794    proto.repeated_int32[1:4] = (i for i in xrange(3))
795    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
796
797    # Test slice assignment.
798    proto.repeated_int32[1:4] = [35, 40, 45]
799    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
800
801    # Test that we can use the field as an iterator.
802    result = []
803    for i in proto.repeated_int32:
804      result.append(i)
805    self.assertEqual([5, 35, 40, 45, 30], result)
806
807    # Test single deletion.
808    del proto.repeated_int32[2]
809    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
810
811    # Test slice deletion.
812    del proto.repeated_int32[2:]
813    self.assertEqual([5, 35], proto.repeated_int32)
814
815    # Test extending.
816    proto.repeated_int32.extend([3, 13])
817    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
818
819    # Test clearing.
820    proto.ClearField('repeated_int32')
821    self.assertTrue(not proto.repeated_int32)
822    self.assertEqual(0, len(proto.repeated_int32))
823
824    proto.repeated_int32.append(1)
825    self.assertEqual(1, proto.repeated_int32[-1])
826    # Test assignment to a negative index.
827    proto.repeated_int32[-1] = 2
828    self.assertEqual(2, proto.repeated_int32[-1])
829
830    # Test deletion at negative indices.
831    proto.repeated_int32[:] = [0, 1, 2, 3]
832    del proto.repeated_int32[-1]
833    self.assertEqual([0, 1, 2], proto.repeated_int32)
834
835    del proto.repeated_int32[-2]
836    self.assertEqual([0, 2], proto.repeated_int32)
837
838    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
839    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
840
841    del proto.repeated_int32[-2:-1]
842    self.assertEqual([2], proto.repeated_int32)
843
844    del proto.repeated_int32[100:10000]
845    self.assertEqual([2], proto.repeated_int32)
846
847  def testRepeatedScalarsRemove(self):
848    proto = unittest_pb2.TestAllTypes()
849
850    self.assertTrue(not proto.repeated_int32)
851    self.assertEqual(0, len(proto.repeated_int32))
852    proto.repeated_int32.append(5)
853    proto.repeated_int32.append(10)
854    proto.repeated_int32.append(5)
855    proto.repeated_int32.append(5)
856
857    self.assertEqual(4, len(proto.repeated_int32))
858    proto.repeated_int32.remove(5)
859    self.assertEqual(3, len(proto.repeated_int32))
860    self.assertEqual(10, proto.repeated_int32[0])
861    self.assertEqual(5, proto.repeated_int32[1])
862    self.assertEqual(5, proto.repeated_int32[2])
863
864    proto.repeated_int32.remove(5)
865    self.assertEqual(2, len(proto.repeated_int32))
866    self.assertEqual(10, proto.repeated_int32[0])
867    self.assertEqual(5, proto.repeated_int32[1])
868
869    proto.repeated_int32.remove(10)
870    self.assertEqual(1, len(proto.repeated_int32))
871    self.assertEqual(5, proto.repeated_int32[0])
872
873    # Remove a non-existent element.
874    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
875
876  def testRepeatedComposites(self):
877    proto = unittest_pb2.TestAllTypes()
878    self.assertTrue(not proto.repeated_nested_message)
879    self.assertEqual(0, len(proto.repeated_nested_message))
880    m0 = proto.repeated_nested_message.add()
881    m1 = proto.repeated_nested_message.add()
882    self.assertTrue(proto.repeated_nested_message)
883    self.assertEqual(2, len(proto.repeated_nested_message))
884    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
885    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
886
887    # Test out-of-bounds indices.
888    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
889                      1234)
890    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
891                      -1234)
892
893    # Test incorrect types passed to __getitem__.
894    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
895                      'foo')
896    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
897                      None)
898
899    # Test slice retrieval.
900    m2 = proto.repeated_nested_message.add()
901    m3 = proto.repeated_nested_message.add()
902    m4 = proto.repeated_nested_message.add()
903    self.assertListsEqual(
904        [m1, m2, m3], proto.repeated_nested_message[1:4])
905    self.assertListsEqual(
906        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
907    self.assertListsEqual(
908        [m0, m1], proto.repeated_nested_message[:2])
909    self.assertListsEqual(
910        [m2, m3, m4], proto.repeated_nested_message[2:])
911    self.assertEqual(
912        m0, proto.repeated_nested_message[0])
913    self.assertListsEqual(
914        [m0], proto.repeated_nested_message[:1])
915
916    # Test that we can use the field as an iterator.
917    result = []
918    for i in proto.repeated_nested_message:
919      result.append(i)
920    self.assertListsEqual([m0, m1, m2, m3, m4], result)
921
922    # Test single deletion.
923    del proto.repeated_nested_message[2]
924    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
925
926    # Test slice deletion.
927    del proto.repeated_nested_message[2:]
928    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
929
930    # Test extending.
931    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
932    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
933    proto.repeated_nested_message.extend([n1,n2])
934    self.assertEqual(4, len(proto.repeated_nested_message))
935    self.assertEqual(n1, proto.repeated_nested_message[2])
936    self.assertEqual(n2, proto.repeated_nested_message[3])
937
938    # Test clearing.
939    proto.ClearField('repeated_nested_message')
940    self.assertTrue(not proto.repeated_nested_message)
941    self.assertEqual(0, len(proto.repeated_nested_message))
942
943    # Test constructing an element while adding it.
944    proto.repeated_nested_message.add(bb=23)
945    self.assertEqual(1, len(proto.repeated_nested_message))
946    self.assertEqual(23, proto.repeated_nested_message[0].bb)
947
948  def testRepeatedCompositeRemove(self):
949    proto = unittest_pb2.TestAllTypes()
950
951    self.assertEqual(0, len(proto.repeated_nested_message))
952    m0 = proto.repeated_nested_message.add()
953    # Need to set some differentiating variable so m0 != m1 != m2:
954    m0.bb = len(proto.repeated_nested_message)
955    m1 = proto.repeated_nested_message.add()
956    m1.bb = len(proto.repeated_nested_message)
957    self.assertTrue(m0 != m1)
958    m2 = proto.repeated_nested_message.add()
959    m2.bb = len(proto.repeated_nested_message)
960    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
961
962    self.assertEqual(3, len(proto.repeated_nested_message))
963    proto.repeated_nested_message.remove(m0)
964    self.assertEqual(2, len(proto.repeated_nested_message))
965    self.assertEqual(m1, proto.repeated_nested_message[0])
966    self.assertEqual(m2, proto.repeated_nested_message[1])
967
968    # Removing m0 again or removing None should raise error
969    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
970    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
971    self.assertEqual(2, len(proto.repeated_nested_message))
972
973    proto.repeated_nested_message.remove(m2)
974    self.assertEqual(1, len(proto.repeated_nested_message))
975    self.assertEqual(m1, proto.repeated_nested_message[0])
976
977  def testHandWrittenReflection(self):
978    # Hand written extensions are only supported by the pure-Python
979    # implementation of the API.
980    if api_implementation.Type() != 'python':
981      return
982
983    FieldDescriptor = descriptor.FieldDescriptor
984    foo_field_descriptor = FieldDescriptor(
985        name='foo_field', full_name='MyProto.foo_field',
986        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
987        cpp_type=FieldDescriptor.CPPTYPE_INT64,
988        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
989        containing_type=None, message_type=None, enum_type=None,
990        is_extension=False, extension_scope=None,
991        options=descriptor_pb2.FieldOptions())
992    mydescriptor = descriptor.Descriptor(
993        name='MyProto', full_name='MyProto', filename='ignored',
994        containing_type=None, nested_types=[], enum_types=[],
995        fields=[foo_field_descriptor], extensions=[],
996        options=descriptor_pb2.MessageOptions())
997    class MyProtoClass(message.Message):
998      DESCRIPTOR = mydescriptor
999      __metaclass__ = reflection.GeneratedProtocolMessageType
1000    myproto_instance = MyProtoClass()
1001    self.assertEqual(0, myproto_instance.foo_field)
1002    self.assertTrue(not myproto_instance.HasField('foo_field'))
1003    myproto_instance.foo_field = 23
1004    self.assertEqual(23, myproto_instance.foo_field)
1005    self.assertTrue(myproto_instance.HasField('foo_field'))
1006
1007  def testDescriptorProtoSupport(self):
1008    # Hand written descriptors/reflection are only supported by the pure-Python
1009    # implementation of the API.
1010    if api_implementation.Type() != 'python':
1011      return
1012
1013    def AddDescriptorField(proto, field_name, field_type):
1014      AddDescriptorField.field_index += 1
1015      new_field = proto.field.add()
1016      new_field.name = field_name
1017      new_field.type = field_type
1018      new_field.number = AddDescriptorField.field_index
1019      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1020
1021    AddDescriptorField.field_index = 0
1022
1023    desc_proto = descriptor_pb2.DescriptorProto()
1024    desc_proto.name = 'Car'
1025    fdp = descriptor_pb2.FieldDescriptorProto
1026    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1027    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1028    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1029    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1030    # Add a repeated field
1031    AddDescriptorField.field_index += 1
1032    new_field = desc_proto.field.add()
1033    new_field.name = 'owners'
1034    new_field.type = fdp.TYPE_STRING
1035    new_field.number = AddDescriptorField.field_index
1036    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1037
1038    desc = descriptor.MakeDescriptor(desc_proto)
1039    self.assertTrue(desc.fields_by_name.has_key('name'))
1040    self.assertTrue(desc.fields_by_name.has_key('year'))
1041    self.assertTrue(desc.fields_by_name.has_key('automatic'))
1042    self.assertTrue(desc.fields_by_name.has_key('price'))
1043    self.assertTrue(desc.fields_by_name.has_key('owners'))
1044
1045    class CarMessage(message.Message):
1046      __metaclass__ = reflection.GeneratedProtocolMessageType
1047      DESCRIPTOR = desc
1048
1049    prius = CarMessage()
1050    prius.name = 'prius'
1051    prius.year = 2010
1052    prius.automatic = True
1053    prius.price = 25134.75
1054    prius.owners.extend(['bob', 'susan'])
1055
1056    serialized_prius = prius.SerializeToString()
1057    new_prius = reflection.ParseMessage(desc, serialized_prius)
1058    self.assertTrue(new_prius is not prius)
1059    self.assertEqual(prius, new_prius)
1060
1061    # these are unnecessary assuming message equality works as advertised but
1062    # explicitly check to be safe since we're mucking about in metaclass foo
1063    self.assertEqual(prius.name, new_prius.name)
1064    self.assertEqual(prius.year, new_prius.year)
1065    self.assertEqual(prius.automatic, new_prius.automatic)
1066    self.assertEqual(prius.price, new_prius.price)
1067    self.assertEqual(prius.owners, new_prius.owners)
1068
1069  def testTopLevelExtensionsForOptionalScalar(self):
1070    extendee_proto = unittest_pb2.TestAllExtensions()
1071    extension = unittest_pb2.optional_int32_extension
1072    self.assertTrue(not extendee_proto.HasExtension(extension))
1073    self.assertEqual(0, extendee_proto.Extensions[extension])
1074    # As with normal scalar fields, just doing a read doesn't actually set the
1075    # "has" bit.
1076    self.assertTrue(not extendee_proto.HasExtension(extension))
1077    # Actually set the thing.
1078    extendee_proto.Extensions[extension] = 23
1079    self.assertEqual(23, extendee_proto.Extensions[extension])
1080    self.assertTrue(extendee_proto.HasExtension(extension))
1081    # Ensure that clearing works as well.
1082    extendee_proto.ClearExtension(extension)
1083    self.assertEqual(0, extendee_proto.Extensions[extension])
1084    self.assertTrue(not extendee_proto.HasExtension(extension))
1085
1086  def testTopLevelExtensionsForRepeatedScalar(self):
1087    extendee_proto = unittest_pb2.TestAllExtensions()
1088    extension = unittest_pb2.repeated_string_extension
1089    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1090    extendee_proto.Extensions[extension].append('foo')
1091    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1092    string_list = extendee_proto.Extensions[extension]
1093    extendee_proto.ClearExtension(extension)
1094    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1095    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
1096    # Shouldn't be allowed to do Extensions[extension] = 'a'
1097    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1098                      extension, 'a')
1099
1100  def testTopLevelExtensionsForOptionalMessage(self):
1101    extendee_proto = unittest_pb2.TestAllExtensions()
1102    extension = unittest_pb2.optional_foreign_message_extension
1103    self.assertTrue(not extendee_proto.HasExtension(extension))
1104    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1105    # As with normal (non-extension) fields, merely reading from the
1106    # thing shouldn't set the "has" bit.
1107    self.assertTrue(not extendee_proto.HasExtension(extension))
1108    extendee_proto.Extensions[extension].c = 23
1109    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1110    self.assertTrue(extendee_proto.HasExtension(extension))
1111    # Save a reference here.
1112    foreign_message = extendee_proto.Extensions[extension]
1113    extendee_proto.ClearExtension(extension)
1114    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
1115    # Setting a field on foreign_message now shouldn't set
1116    # any "has" bits on extendee_proto.
1117    foreign_message.c = 42
1118    self.assertEqual(42, foreign_message.c)
1119    self.assertTrue(foreign_message.HasField('c'))
1120    self.assertTrue(not extendee_proto.HasExtension(extension))
1121    # Shouldn't be allowed to do Extensions[extension] = 'a'
1122    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1123                      extension, 'a')
1124
1125  def testTopLevelExtensionsForRepeatedMessage(self):
1126    extendee_proto = unittest_pb2.TestAllExtensions()
1127    extension = unittest_pb2.repeatedgroup_extension
1128    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1129    group = extendee_proto.Extensions[extension].add()
1130    group.a = 23
1131    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1132    group.a = 42
1133    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1134    group_list = extendee_proto.Extensions[extension]
1135    extendee_proto.ClearExtension(extension)
1136    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1137    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
1138    # Shouldn't be allowed to do Extensions[extension] = 'a'
1139    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1140                      extension, 'a')
1141
1142  def testNestedExtensions(self):
1143    extendee_proto = unittest_pb2.TestAllExtensions()
1144    extension = unittest_pb2.TestRequired.single
1145
1146    # We just test the non-repeated case.
1147    self.assertTrue(not extendee_proto.HasExtension(extension))
1148    required = extendee_proto.Extensions[extension]
1149    self.assertEqual(0, required.a)
1150    self.assertTrue(not extendee_proto.HasExtension(extension))
1151    required.a = 23
1152    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1153    self.assertTrue(extendee_proto.HasExtension(extension))
1154    extendee_proto.ClearExtension(extension)
1155    self.assertTrue(required is not extendee_proto.Extensions[extension])
1156    self.assertTrue(not extendee_proto.HasExtension(extension))
1157
1158  # If message A directly contains message B, and
1159  # a.HasField('b') is currently False, then mutating any
1160  # extension in B should change a.HasField('b') to True
1161  # (and so on up the object tree).
1162  def testHasBitsForAncestorsOfExtendedMessage(self):
1163    # Optional scalar extension.
1164    toplevel = more_extensions_pb2.TopLevelMessage()
1165    self.assertTrue(not toplevel.HasField('submessage'))
1166    self.assertEqual(0, toplevel.submessage.Extensions[
1167        more_extensions_pb2.optional_int_extension])
1168    self.assertTrue(not toplevel.HasField('submessage'))
1169    toplevel.submessage.Extensions[
1170        more_extensions_pb2.optional_int_extension] = 23
1171    self.assertEqual(23, toplevel.submessage.Extensions[
1172        more_extensions_pb2.optional_int_extension])
1173    self.assertTrue(toplevel.HasField('submessage'))
1174
1175    # Repeated scalar extension.
1176    toplevel = more_extensions_pb2.TopLevelMessage()
1177    self.assertTrue(not toplevel.HasField('submessage'))
1178    self.assertEqual([], toplevel.submessage.Extensions[
1179        more_extensions_pb2.repeated_int_extension])
1180    self.assertTrue(not toplevel.HasField('submessage'))
1181    toplevel.submessage.Extensions[
1182        more_extensions_pb2.repeated_int_extension].append(23)
1183    self.assertEqual([23], toplevel.submessage.Extensions[
1184        more_extensions_pb2.repeated_int_extension])
1185    self.assertTrue(toplevel.HasField('submessage'))
1186
1187    # Optional message extension.
1188    toplevel = more_extensions_pb2.TopLevelMessage()
1189    self.assertTrue(not toplevel.HasField('submessage'))
1190    self.assertEqual(0, toplevel.submessage.Extensions[
1191        more_extensions_pb2.optional_message_extension].foreign_message_int)
1192    self.assertTrue(not toplevel.HasField('submessage'))
1193    toplevel.submessage.Extensions[
1194        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1195    self.assertEqual(23, toplevel.submessage.Extensions[
1196        more_extensions_pb2.optional_message_extension].foreign_message_int)
1197    self.assertTrue(toplevel.HasField('submessage'))
1198
1199    # Repeated message extension.
1200    toplevel = more_extensions_pb2.TopLevelMessage()
1201    self.assertTrue(not toplevel.HasField('submessage'))
1202    self.assertEqual(0, len(toplevel.submessage.Extensions[
1203        more_extensions_pb2.repeated_message_extension]))
1204    self.assertTrue(not toplevel.HasField('submessage'))
1205    foreign = toplevel.submessage.Extensions[
1206        more_extensions_pb2.repeated_message_extension].add()
1207    self.assertEqual(foreign, toplevel.submessage.Extensions[
1208        more_extensions_pb2.repeated_message_extension][0])
1209    self.assertTrue(toplevel.HasField('submessage'))
1210
1211  def testDisconnectionAfterClearingEmptyMessage(self):
1212    toplevel = more_extensions_pb2.TopLevelMessage()
1213    extendee_proto = toplevel.submessage
1214    extension = more_extensions_pb2.optional_message_extension
1215    extension_proto = extendee_proto.Extensions[extension]
1216    extendee_proto.ClearExtension(extension)
1217    extension_proto.foreign_message_int = 23
1218
1219    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
1220
1221  def testExtensionFailureModes(self):
1222    extendee_proto = unittest_pb2.TestAllExtensions()
1223
1224    # Try non-extension-handle arguments to HasExtension,
1225    # ClearExtension(), and Extensions[]...
1226    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1227    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1228    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1229    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1230
1231    # Try something that *is* an extension handle, just not for
1232    # this message...
1233    unknown_handle = more_extensions_pb2.optional_int_extension
1234    self.assertRaises(KeyError, extendee_proto.HasExtension,
1235                      unknown_handle)
1236    self.assertRaises(KeyError, extendee_proto.ClearExtension,
1237                      unknown_handle)
1238    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1239                      unknown_handle)
1240    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1241                      unknown_handle, 5)
1242
1243    # Try call HasExtension() with a valid handle, but for a
1244    # *repeated* field.  (Just as with non-extension repeated
1245    # fields, Has*() isn't supported for extension repeated fields).
1246    self.assertRaises(KeyError, extendee_proto.HasExtension,
1247                      unittest_pb2.repeated_string_extension)
1248
1249  def testStaticParseFrom(self):
1250    proto1 = unittest_pb2.TestAllTypes()
1251    test_util.SetAllFields(proto1)
1252
1253    string1 = proto1.SerializeToString()
1254    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
1255
1256    # Messages should be equal.
1257    self.assertEqual(proto2, proto1)
1258
1259  def testMergeFromSingularField(self):
1260    # Test merge with just a singular field.
1261    proto1 = unittest_pb2.TestAllTypes()
1262    proto1.optional_int32 = 1
1263
1264    proto2 = unittest_pb2.TestAllTypes()
1265    # This shouldn't get overwritten.
1266    proto2.optional_string = 'value'
1267
1268    proto2.MergeFrom(proto1)
1269    self.assertEqual(1, proto2.optional_int32)
1270    self.assertEqual('value', proto2.optional_string)
1271
1272  def testMergeFromRepeatedField(self):
1273    # Test merge with just a repeated field.
1274    proto1 = unittest_pb2.TestAllTypes()
1275    proto1.repeated_int32.append(1)
1276    proto1.repeated_int32.append(2)
1277
1278    proto2 = unittest_pb2.TestAllTypes()
1279    proto2.repeated_int32.append(0)
1280    proto2.MergeFrom(proto1)
1281
1282    self.assertEqual(0, proto2.repeated_int32[0])
1283    self.assertEqual(1, proto2.repeated_int32[1])
1284    self.assertEqual(2, proto2.repeated_int32[2])
1285
1286  def testMergeFromOptionalGroup(self):
1287    # Test merge with an optional group.
1288    proto1 = unittest_pb2.TestAllTypes()
1289    proto1.optionalgroup.a = 12
1290    proto2 = unittest_pb2.TestAllTypes()
1291    proto2.MergeFrom(proto1)
1292    self.assertEqual(12, proto2.optionalgroup.a)
1293
1294  def testMergeFromRepeatedNestedMessage(self):
1295    # Test merge with a repeated nested message.
1296    proto1 = unittest_pb2.TestAllTypes()
1297    m = proto1.repeated_nested_message.add()
1298    m.bb = 123
1299    m = proto1.repeated_nested_message.add()
1300    m.bb = 321
1301
1302    proto2 = unittest_pb2.TestAllTypes()
1303    m = proto2.repeated_nested_message.add()
1304    m.bb = 999
1305    proto2.MergeFrom(proto1)
1306    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
1307    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
1308    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
1309
1310    proto3 = unittest_pb2.TestAllTypes()
1311    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
1312    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
1313    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
1314    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
1315
1316  def testMergeFromAllFields(self):
1317    # With all fields set.
1318    proto1 = unittest_pb2.TestAllTypes()
1319    test_util.SetAllFields(proto1)
1320    proto2 = unittest_pb2.TestAllTypes()
1321    proto2.MergeFrom(proto1)
1322
1323    # Messages should be equal.
1324    self.assertEqual(proto2, proto1)
1325
1326    # Serialized string should be equal too.
1327    string1 = proto1.SerializeToString()
1328    string2 = proto2.SerializeToString()
1329    self.assertEqual(string1, string2)
1330
1331  def testMergeFromExtensionsSingular(self):
1332    proto1 = unittest_pb2.TestAllExtensions()
1333    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1334
1335    proto2 = unittest_pb2.TestAllExtensions()
1336    proto2.MergeFrom(proto1)
1337    self.assertEqual(
1338        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1339
1340  def testMergeFromExtensionsRepeated(self):
1341    proto1 = unittest_pb2.TestAllExtensions()
1342    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1343    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1344
1345    proto2 = unittest_pb2.TestAllExtensions()
1346    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1347    proto2.MergeFrom(proto1)
1348    self.assertEqual(
1349        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1350    self.assertEqual(
1351        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1352    self.assertEqual(
1353        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1354    self.assertEqual(
1355        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1356
1357  def testMergeFromExtensionsNestedMessage(self):
1358    proto1 = unittest_pb2.TestAllExtensions()
1359    ext1 = proto1.Extensions[
1360        unittest_pb2.repeated_nested_message_extension]
1361    m = ext1.add()
1362    m.bb = 222
1363    m = ext1.add()
1364    m.bb = 333
1365
1366    proto2 = unittest_pb2.TestAllExtensions()
1367    ext2 = proto2.Extensions[
1368        unittest_pb2.repeated_nested_message_extension]
1369    m = ext2.add()
1370    m.bb = 111
1371
1372    proto2.MergeFrom(proto1)
1373    ext2 = proto2.Extensions[
1374        unittest_pb2.repeated_nested_message_extension]
1375    self.assertEqual(3, len(ext2))
1376    self.assertEqual(111, ext2[0].bb)
1377    self.assertEqual(222, ext2[1].bb)
1378    self.assertEqual(333, ext2[2].bb)
1379
1380  def testMergeFromBug(self):
1381    message1 = unittest_pb2.TestAllTypes()
1382    message2 = unittest_pb2.TestAllTypes()
1383
1384    # Cause optional_nested_message to be instantiated within message1, even
1385    # though it is not considered to be "present".
1386    message1.optional_nested_message
1387    self.assertFalse(message1.HasField('optional_nested_message'))
1388
1389    # Merge into message2.  This should not instantiate the field is message2.
1390    message2.MergeFrom(message1)
1391    self.assertFalse(message2.HasField('optional_nested_message'))
1392
1393  def testCopyFromSingularField(self):
1394    # Test copy with just a singular field.
1395    proto1 = unittest_pb2.TestAllTypes()
1396    proto1.optional_int32 = 1
1397    proto1.optional_string = 'important-text'
1398
1399    proto2 = unittest_pb2.TestAllTypes()
1400    proto2.optional_string = 'value'
1401
1402    proto2.CopyFrom(proto1)
1403    self.assertEqual(1, proto2.optional_int32)
1404    self.assertEqual('important-text', proto2.optional_string)
1405
1406  def testCopyFromRepeatedField(self):
1407    # Test copy with a repeated field.
1408    proto1 = unittest_pb2.TestAllTypes()
1409    proto1.repeated_int32.append(1)
1410    proto1.repeated_int32.append(2)
1411
1412    proto2 = unittest_pb2.TestAllTypes()
1413    proto2.repeated_int32.append(0)
1414    proto2.CopyFrom(proto1)
1415
1416    self.assertEqual(1, proto2.repeated_int32[0])
1417    self.assertEqual(2, proto2.repeated_int32[1])
1418
1419  def testCopyFromAllFields(self):
1420    # With all fields set.
1421    proto1 = unittest_pb2.TestAllTypes()
1422    test_util.SetAllFields(proto1)
1423    proto2 = unittest_pb2.TestAllTypes()
1424    proto2.CopyFrom(proto1)
1425
1426    # Messages should be equal.
1427    self.assertEqual(proto2, proto1)
1428
1429    # Serialized string should be equal too.
1430    string1 = proto1.SerializeToString()
1431    string2 = proto2.SerializeToString()
1432    self.assertEqual(string1, string2)
1433
1434  def testCopyFromSelf(self):
1435    proto1 = unittest_pb2.TestAllTypes()
1436    proto1.repeated_int32.append(1)
1437    proto1.optional_int32 = 2
1438    proto1.optional_string = 'important-text'
1439
1440    proto1.CopyFrom(proto1)
1441    self.assertEqual(1, proto1.repeated_int32[0])
1442    self.assertEqual(2, proto1.optional_int32)
1443    self.assertEqual('important-text', proto1.optional_string)
1444
1445  def testCopyFromBadType(self):
1446    # The python implementation doesn't raise an exception in this
1447    # case. In theory it should.
1448    if api_implementation.Type() == 'python':
1449      return
1450    proto1 = unittest_pb2.TestAllTypes()
1451    proto2 = unittest_pb2.TestAllExtensions()
1452    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1453
1454  def testClear(self):
1455    proto = unittest_pb2.TestAllTypes()
1456    # C++ implementation does not support lazy fields right now so leave it
1457    # out for now.
1458    if api_implementation.Type() == 'python':
1459      test_util.SetAllFields(proto)
1460    else:
1461      test_util.SetAllNonLazyFields(proto)
1462    # Clear the message.
1463    proto.Clear()
1464    self.assertEquals(proto.ByteSize(), 0)
1465    empty_proto = unittest_pb2.TestAllTypes()
1466    self.assertEquals(proto, empty_proto)
1467
1468    # Test if extensions which were set are cleared.
1469    proto = unittest_pb2.TestAllExtensions()
1470    test_util.SetAllExtensions(proto)
1471    # Clear the message.
1472    proto.Clear()
1473    self.assertEquals(proto.ByteSize(), 0)
1474    empty_proto = unittest_pb2.TestAllExtensions()
1475    self.assertEquals(proto, empty_proto)
1476
1477  def testDisconnectingBeforeClear(self):
1478    proto = unittest_pb2.TestAllTypes()
1479    nested = proto.optional_nested_message
1480    proto.Clear()
1481    self.assertTrue(nested is not proto.optional_nested_message)
1482    nested.bb = 23
1483    self.assertTrue(not proto.HasField('optional_nested_message'))
1484    self.assertEqual(0, proto.optional_nested_message.bb)
1485
1486    proto = unittest_pb2.TestAllTypes()
1487    nested = proto.optional_nested_message
1488    nested.bb = 5
1489    foreign = proto.optional_foreign_message
1490    foreign.c = 6
1491
1492    proto.Clear()
1493    self.assertTrue(nested is not proto.optional_nested_message)
1494    self.assertTrue(foreign is not proto.optional_foreign_message)
1495    self.assertEqual(5, nested.bb)
1496    self.assertEqual(6, foreign.c)
1497    nested.bb = 15
1498    foreign.c = 16
1499    self.assertTrue(not proto.HasField('optional_nested_message'))
1500    self.assertEqual(0, proto.optional_nested_message.bb)
1501    self.assertTrue(not proto.HasField('optional_foreign_message'))
1502    self.assertEqual(0, proto.optional_foreign_message.c)
1503
1504  def assertInitialized(self, proto):
1505    self.assertTrue(proto.IsInitialized())
1506    # Neither method should raise an exception.
1507    proto.SerializeToString()
1508    proto.SerializePartialToString()
1509
1510  def assertNotInitialized(self, proto):
1511    self.assertFalse(proto.IsInitialized())
1512    self.assertRaises(message.EncodeError, proto.SerializeToString)
1513    # "Partial" serialization doesn't care if message is uninitialized.
1514    proto.SerializePartialToString()
1515
1516  def testIsInitialized(self):
1517    # Trivial cases - all optional fields and extensions.
1518    proto = unittest_pb2.TestAllTypes()
1519    self.assertInitialized(proto)
1520    proto = unittest_pb2.TestAllExtensions()
1521    self.assertInitialized(proto)
1522
1523    # The case of uninitialized required fields.
1524    proto = unittest_pb2.TestRequired()
1525    self.assertNotInitialized(proto)
1526    proto.a = proto.b = proto.c = 2
1527    self.assertInitialized(proto)
1528
1529    # The case of uninitialized submessage.
1530    proto = unittest_pb2.TestRequiredForeign()
1531    self.assertInitialized(proto)
1532    proto.optional_message.a = 1
1533    self.assertNotInitialized(proto)
1534    proto.optional_message.b = 0
1535    proto.optional_message.c = 0
1536    self.assertInitialized(proto)
1537
1538    # Uninitialized repeated submessage.
1539    message1 = proto.repeated_message.add()
1540    self.assertNotInitialized(proto)
1541    message1.a = message1.b = message1.c = 0
1542    self.assertInitialized(proto)
1543
1544    # Uninitialized repeated group in an extension.
1545    proto = unittest_pb2.TestAllExtensions()
1546    extension = unittest_pb2.TestRequired.multi
1547    message1 = proto.Extensions[extension].add()
1548    message2 = proto.Extensions[extension].add()
1549    self.assertNotInitialized(proto)
1550    message1.a = 1
1551    message1.b = 1
1552    message1.c = 1
1553    self.assertNotInitialized(proto)
1554    message2.a = 2
1555    message2.b = 2
1556    message2.c = 2
1557    self.assertInitialized(proto)
1558
1559    # Uninitialized nonrepeated message in an extension.
1560    proto = unittest_pb2.TestAllExtensions()
1561    extension = unittest_pb2.TestRequired.single
1562    proto.Extensions[extension].a = 1
1563    self.assertNotInitialized(proto)
1564    proto.Extensions[extension].b = 2
1565    proto.Extensions[extension].c = 3
1566    self.assertInitialized(proto)
1567
1568    # Try passing an errors list.
1569    errors = []
1570    proto = unittest_pb2.TestRequired()
1571    self.assertFalse(proto.IsInitialized(errors))
1572    self.assertEqual(errors, ['a', 'b', 'c'])
1573
1574  def testStringUTF8Encoding(self):
1575    proto = unittest_pb2.TestAllTypes()
1576
1577    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
1578    self.assertRaises(TypeError,
1579                      setattr, proto, 'optional_bytes', u'unicode object')
1580
1581    # Check that the default value is of python's 'unicode' type.
1582    self.assertEqual(type(proto.optional_string), unicode)
1583
1584    proto.optional_string = unicode('Testing')
1585    self.assertEqual(proto.optional_string, str('Testing'))
1586
1587    # Assign a value of type 'str' which can be encoded in UTF-8.
1588    proto.optional_string = str('Testing')
1589    self.assertEqual(proto.optional_string, unicode('Testing'))
1590
1591    if api_implementation.Type() == 'python':
1592      # Values of type 'str' are also accepted as long as they can be
1593      # encoded in UTF-8.
1594      self.assertEqual(type(proto.optional_string), str)
1595
1596    # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
1597    self.assertRaises(ValueError,
1598                      setattr, proto, 'optional_string', str('a\x80a'))
1599    # Assign a 'str' object which contains a UTF-8 encoded string.
1600    self.assertRaises(ValueError,
1601                      setattr, proto, 'optional_string', 'Тест')
1602    # No exception thrown.
1603    proto.optional_string = 'abc'
1604
1605  def testStringUTF8Serialization(self):
1606    proto = unittest_mset_pb2.TestMessageSet()
1607    extension_message = unittest_mset_pb2.TestMessageSetExtension2
1608    extension = extension_message.message_set_extension
1609
1610    test_utf8 = u'Тест'
1611    test_utf8_bytes = test_utf8.encode('utf-8')
1612
1613    # 'Test' in another language, using UTF-8 charset.
1614    proto.Extensions[extension].str = test_utf8
1615
1616    # Serialize using the MessageSet wire format (this is specified in the
1617    # .proto file).
1618    serialized = proto.SerializeToString()
1619
1620    # Check byte size.
1621    self.assertEqual(proto.ByteSize(), len(serialized))
1622
1623    raw = unittest_mset_pb2.RawMessageSet()
1624    raw.MergeFromString(serialized)
1625
1626    message2 = unittest_mset_pb2.TestMessageSetExtension2()
1627
1628    self.assertEqual(1, len(raw.item))
1629    # Check that the type_id is the same as the tag ID in the .proto file.
1630    self.assertEqual(raw.item[0].type_id, 1547769)
1631
1632    # Check the actual bytes on the wire.
1633    self.assertTrue(
1634        raw.item[0].message.endswith(test_utf8_bytes))
1635    message2.MergeFromString(raw.item[0].message)
1636
1637    self.assertEqual(type(message2.str), unicode)
1638    self.assertEqual(message2.str, test_utf8)
1639
1640    # The pure Python API throws an exception on MergeFromString(),
1641    # if any of the string fields of the message can't be UTF-8 decoded.
1642    # The C++ implementation of the API has no way to check that on
1643    # MergeFromString and thus has no way to throw the exception.
1644    #
1645    # The pure Python API always returns objects of type 'unicode' (UTF-8
1646    # encoded), or 'str' (in 7 bit ASCII).
1647    bytes = raw.item[0].message.replace(
1648        test_utf8_bytes, len(test_utf8_bytes) * '\xff')
1649
1650    unicode_decode_failed = False
1651    try:
1652      message2.MergeFromString(bytes)
1653    except UnicodeDecodeError as e:
1654      unicode_decode_failed = True
1655    string_field = message2.str
1656    self.assertTrue(unicode_decode_failed or type(string_field) == str)
1657
1658  def testEmptyNestedMessage(self):
1659    proto = unittest_pb2.TestAllTypes()
1660    proto.optional_nested_message.MergeFrom(
1661        unittest_pb2.TestAllTypes.NestedMessage())
1662    self.assertTrue(proto.HasField('optional_nested_message'))
1663
1664    proto = unittest_pb2.TestAllTypes()
1665    proto.optional_nested_message.CopyFrom(
1666        unittest_pb2.TestAllTypes.NestedMessage())
1667    self.assertTrue(proto.HasField('optional_nested_message'))
1668
1669    proto = unittest_pb2.TestAllTypes()
1670    proto.optional_nested_message.MergeFromString('')
1671    self.assertTrue(proto.HasField('optional_nested_message'))
1672
1673    proto = unittest_pb2.TestAllTypes()
1674    proto.optional_nested_message.ParseFromString('')
1675    self.assertTrue(proto.HasField('optional_nested_message'))
1676
1677    serialized = proto.SerializeToString()
1678    proto2 = unittest_pb2.TestAllTypes()
1679    proto2.MergeFromString(serialized)
1680    self.assertTrue(proto2.HasField('optional_nested_message'))
1681
1682  def testSetInParent(self):
1683    proto = unittest_pb2.TestAllTypes()
1684    self.assertFalse(proto.HasField('optionalgroup'))
1685    proto.optionalgroup.SetInParent()
1686    self.assertTrue(proto.HasField('optionalgroup'))
1687
1688
1689#  Since we had so many tests for protocol buffer equality, we broke these out
1690#  into separate TestCase classes.
1691
1692
class TestAllTypesEqualityTest(unittest.TestCase):

  """Equality checks on freshly constructed, empty TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    # Mutable messages must refuse to be hashed.
    self.assertRaises(TypeError, hash, self.first_proto)

  def testSelfEquality(self):
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    # Two distinct empty messages compare equal.
    self.assertEqual(self.first_proto, self.second_proto)
1707
1708
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    for proto in (self.first_proto, self.second_proto):
      test_util.SetAllFields(proto)

  def testNotHashable(self):
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    # Comparison against None is unequal from either side.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    # A message of a different type never compares equal.
    other_type = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other_type)
    self.assertNotEqual(other_type, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Changing a singular scalar field breaks equality...
    first.optional_int32 += 1
    self.assertNotEqual(first, second)
    # ...and so does clearing it.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Modify, then restore, a field inside the singular sub-message.
    first.optional_nested_message.bb += 1
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb -= 1
    self.assertEqual(first, second)
    # Clearing a field inside the sub-message breaks equality; restoring the
    # value brings it back.
    first.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # Dropping the whole sub-message breaks equality.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # An extra element, or clearing the field, both break equality.
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Modify, then restore, a value inside a repeated sub-message.
    first.repeated_nested_message[0].bb += 1
    self.assertNotEqual(first, second)
    first.repeated_nested_message[0].bb -= 1
    self.assertEqual(first, second)
    # Appending an element on one side only breaks equality until the other
    # side gets a matching element.
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    # "Has" bits matter as well as values: an unset scalar differs from one
    # explicitly set to 0.
    first, second = self.first_proto, self.second_proto
    first.ClearField('optional_int32')
    second.optional_int32 = 0
    self.assertNotEqual(first, second)

  def testNonRepeatedCompositeHasBits(self):
    # "Has" bits matter for singular sub-messages too: an absent sub-message
    # differs from a present-but-empty one.
    first, second = self.first_proto, self.second_proto
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # Setting then clearing a field leaves the sub-message present but
    # empty, matching second.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
1793
1794
class ExtensionEqualityTest(unittest.TestCase):

  def testExtensionEquality(self):
    int32_ext = unittest_pb2.optional_int32_extension

    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    # Setting extensions on one side only breaks equality until both match.
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Extension values are compared...
    first_proto.Extensions[int32_ext] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[int32_ext] -= 1
    self.assertEqual(first_proto, second_proto)

    # ...and so are "has" bits: a cleared extension differs from one
    # explicitly set to 0.
    first_proto.ClearExtension(int32_ext)
    second_proto.Extensions[int32_ext] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[int32_ext] = 0
    self.assertEqual(first_proto, second_proto)

    # Reading an unset extension must not affect equality when both "has"
    # bits are false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, first_proto.Extensions[int32_ext])
    self.assertEqual(first_proto, second_proto)
1826
1827
class MutualRecursionEqualityTest(unittest.TestCase):

  def testEqualityWithMutualRecursion(self):
    # Equality must look through mutually recursive message types.
    lhs = unittest_pb2.TestMutualRecursionA()
    rhs = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(lhs, rhs)
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
1838
1839
class ByteSizeTest(unittest.TestCase):

  """Checks ByteSize() against hand-computed serialized sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto, in bytes."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one byte for the field's tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # A varint carries 7 payload bits per byte, so (1 << 7*k) - 1 is the
    # largest value that still fits in k bytes (k = 1..8 here).
    for num_bytes, bits in enumerate(range(7, 63, 7), 1):
      Test((1 << bits) - 1, num_bytes)
    # Negative values always take the maximum of 10 varint bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_1 = self.proto.repeated_nested_message.add()
    nested_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_1 = self.proto.repeated_nested_message.add()
    nested_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the empty element leaves only nested_1:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    nested_2 = self.proto.repeated_nested_message.add()
    nested_2.bb = 12

    # Two elements, each:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the second element leaves one:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])   # 16 bytes
    # Tag is 3 bytes.
    self.assertEqual(19, self.packed_extended_proto.ByteSize())
2123
2124
2125# Issues to be sure to cover include:
2126#   * Handling of unrecognized tags ("uninterpreted_bytes").
2127#   * Handling of MessageSets.
2128#   * Consistent ordering of tags in the wire format,
2129#     including ordering between extensions and non-extension
2130#     fields.
2131#   * Consistent serialization of negative numbers, especially
2132#     negative int32s.
2133#   * Handling of empty submessages (with and without "has"
2134#     bits set).
2135
class SerializationTest(unittest.TestCase):
  """Serialization and parsing tests for generated message classes.

  Covers SerializeToString()/MergeFromString()/ParseFromString() round trips
  for plain, extension, packed and MessageSet messages, canonical field
  ordering on the wire, required-field enforcement, and keyword-argument
  construction.
  """

  def testSerializeEmtpyMessage(self):
    # NOTE: the "Emtpy" misspelling is kept in the method name so existing
    # test selectors that reference it keep working.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    # This test is only applicable for the Python implementation of the API,
    # because it calls the private _InternalParse() hook directly.
    if api_implementation.Type() != 'python':
      return

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()

    # Try every prefix length, including the complete buffer. For each one,
    # parsing into the typed message and into an all-unknown-fields message
    # must agree on success vs. DecodeError.
    for truncation_point in xrange(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    second_proto.MergeFromString(serialized)
    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite non-repeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # RawMessageSet does not use the MessageSet wire format, so parsing into
    # it exposes the raw (type_id, message) items.
    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    raw.MergeFromString(serialized)
    self.assertEqual(2, len(raw.item))

    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.MergeFromString(raw.item[0].message)
    self.assertEqual(123, message1.i)

    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.MergeFromString(raw.item[1].message)
    self.assertEqual('foo', message2.str)

    # Deserialize using the MessageSet wire format.
    proto2 = unittest_mset_pb2.TestMessageSet()
    proto2.MergeFromString(serialized)
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item. Its type_id must match TestMessageSetExtension1's
    # extension number for the final assertion below to hold.
    item = raw.item.add()
    item.type_id = 1545008
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 1545009
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 1545010
    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.str = 'foo'
    item.message = message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = unittest_mset_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    # Check that the message parsed well.
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    proto2.MergeFromString(serialized)

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    proto2.MergeFromString(serialized)

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj() raises exc_class with the given message.

    Args:
      exc_class: Exception type that callable_obj is expected to raise.
      callable_obj: Zero-argument callable to invoke.
      exception: Expected str() of the raised exception.

    Raises:
      self.failureException: if nothing of type exc_class is raised, or if
        its message differs from `exception`.
    """
    try:
      callable_obj()
    except exc_class as ex:
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
      return
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    proto2.MergeFromString(serialized)
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    proto2.ParseFromString(partial)
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    # Merged packed values are appended after the existing ones; fields not
    # present in `serialized` are left untouched.
    second_proto.MergeFromString(serialized)
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    # Each packed field is one length-delimited record: tag, byte length,
    # then the concatenated element encodings.
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    packed.MergeFromString(unpacked.SerializeToString())
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    unpacked.MergeFromString(packed.SerializeToString())
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      51)

  def testInitKwargs(self):
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes='bar',
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_int32'))
    self.assertTrue(proto.HasField('optional_string'))
    self.assertTrue(proto.HasField('optional_bool'))
    self.assertTrue(proto.HasField('optional_bytes'))
    self.assertTrue(proto.HasField('optional_nested_message'))
    self.assertTrue(proto.HasField('optional_foreign_message'))
    self.assertTrue(proto.HasField('optional_nested_enum'))
    self.assertTrue(proto.HasField('optional_foreign_enum'))
    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual('bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)

  def testInitArgsUnknownFieldName(self):
    def InitializeEmptyMessageWithExtraKeywordArg():
      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
    self._CheckRaises(ValueError,
                      InitializeEmptyMessageWithExtraKeywordArg,
                      'Protocol message has no "unknown" field.')

  def testInitRequiredKwargs(self):
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)

  def testInitRequiredForeignKwargs(self):
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)

  def testInitRepeatedKwargs(self):
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertEqual(1, proto.repeated_int32[0])
    self.assertEqual(2, proto.repeated_int32[1])
    self.assertEqual(3, proto.repeated_int32[2])
2641
2642
class OptionsTest(unittest.TestCase):
  """Checks that descriptor options are visible through GetOptions()."""

  def testMessageOptions(self):
    # Only TestMessageSet opts into the MessageSet wire format.
    expectations = [
        (unittest_mset_pb2.TestMessageSet, True),
        (unittest_pb2.TestAllTypes, False),
    ]
    for message_class, expected in expectations:
      instance = message_class()
      options = instance.DESCRIPTOR.GetOptions()
      self.assertEqual(expected, options.message_set_wire_format)

  def testPackedOptions(self):
    # Ordinary scalar fields are not packed.
    unpacked = unittest_pb2.TestAllTypes()
    unpacked.optional_int32 = 1
    unpacked.optional_double = 3.0
    for descriptor, unused_value in unpacked.ListFields():
      self.assertEqual(False, descriptor.GetOptions().packed)

    # Every set field of TestPackedTypes is a packed, repeated field.
    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    for descriptor, unused_value in packed.ListFields():
      options = descriptor.GetOptions()
      self.assertEqual(True, options.packed)
      self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
                       descriptor.label)
2667
2668
2669
if __name__ == '__main__':
  # When run as a script, execute all test cases defined in this module.
  unittest.main()
2672