#! /usr/bin/python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Unittest for reflection.py, which also indirectly tests the output of the
pure-Python protocol compiler.
"""

__author__ = 'robinson@google.com (Will Robinson)'

import operator
import struct

import unittest
# TODO(robinson): When we split this test in two, only some of these imports
# will be necessary in each test.
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import decoder


class _MiniDecoder(object):
  """Decodes a stream of values from a string.

  Once upon a time we actually had a class called decoder.Decoder.  Then we
  got rid of it during a redesign that made decoding much, much faster overall.
  But a couple tests in this file used it to check that the serialized form of
  a message was correct.  So, this class implements just the methods that were
  used by said tests, so that we don't have to rewrite the tests.
  """

  def __init__(self, bytes):
    """Starts decoding at offset 0 of the serialized data in |bytes|."""
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    """Decodes one varint at the current position and advances past it."""
    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
    return result

  # All of these are read as a plain varint; no per-type post-processing is
  # applied here.
  ReadInt32 = ReadVarint
  ReadInt64 = ReadVarint
  ReadUInt32 = ReadVarint
  ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Reads one varint and returns its zig-zag-decoded value."""
    return wire_format.ZigZagDecode(self.ReadVarint())

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Reads one varint and returns the result of wire_format.UnpackTag on it.

    Presumably a (field_number, wire_type) pair, as the method name suggests;
    see wire_format.UnpackTag for the exact contract.
    """
    return wire_format.UnpackTag(self.ReadVarint())

  def ReadFloat(self):
    """Reads the next 4 bytes as a little-endian float and advances past them."""
    result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
    self._pos += 4
    return result

  def ReadDouble(self):
    """Reads the next 8 bytes as a little-endian double and advances past them."""
    result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
    self._pos += 8
    return result

  def EndOfStream(self):
    """Returns True iff the read position is exactly at the end of the data."""
    return self._pos == len(self._bytes)


class ReflectionTest(unittest.TestCase):

  def assertIs(self, values, others):
    # Asserts that two sequences have the same length and hold the very same
    # objects (identity, not equality) position by position.
    #
    # NOTE: on Python 2.7+ this shadows unittest.TestCase.assertIs(a, b),
    # which compares two single objects.  The name is kept because tests in
    # this class call it with sequences.
    self.assertEqual(len(values), len(others))
    for value, other in zip(values, others):
      self.assertTrue(value is other)

  def testScalarConstructor(self):
    # Constructor with only scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_double=54.321,
        optional_string='optional_string')

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual(54.321, proto.optional_double)
    self.assertEqual('optional_string', proto.optional_string)

  def testRepeatedScalarConstructor(self):
    # Constructor with only repeated scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_int32=[1, 2, 3, 4],
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_string=["optional_string"])

    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(["optional_string"], list(proto.repeated_string))

  def testRepeatedCompositeConstructor(self):
    # Constructor with only repeated composite types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)],
        repeatedgroup=[
            unittest_pb2.TestAllTypes.RepeatedGroup(),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])

    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.RepeatedGroup(),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
        list(proto.repeatedgroup))

  def testMixedConstructor(self):
    # Constructor with only mixed types should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_string='optional_string',
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)])

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual('optional_string', proto.optional_string)
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))

  def testSimpleHasBits(self):
    # Test a scalar.
    proto = unittest_pb2.TestAllTypes()
    self.assertTrue(not proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is
    # read the default value.
    self.assertTrue(not proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertTrue(not proto.HasField('optional_int32'))

  def testHasBitsWithSinglyNestedScalar(self):
    # Helper used to test foreign messages and groups.
    #
    # composite_field_name should be the name of a non-repeated
    # composite (i.e., foreign or group) field in TestAllTypes,
    # and scalar_field_name should be the name of an integer-valued
    # scalar field within that composite.
    #
    # I never thought I'd miss C++ macros and templates so much. :(
    # This helper is semantically just:
    #
    #   assert proto.composite_field.scalar_field == 0
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #
    #   proto.composite_field.scalar_field = 10
    #   old_composite_field = proto.composite_field
    #
    #   assert proto.composite_field.scalar_field == 10
    #   assert proto.composite_field.HasField('scalar_field')
    #   assert proto.HasField('composite_field')
    #
    #   proto.ClearField('composite_field')
    #
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #   assert proto.composite_field.scalar_field == 0
    #
    #   # Now ensure that ClearField('composite_field') disconnected
    #   # the old field object from the object tree...
    #   assert old_composite_field is not proto.composite_field
    #   old_composite_field.scalar_field = 20
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value, and see that it's the
      # default (0), but that proto.HasField('composite') and
      # proto.composite.HasField('scalar') will still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # Assert that the composite object does not "have" the scalar.
      self.assertTrue(not composite_field.HasField(scalar_field_name))
      # Assert that proto does not "have" the composite field.
      self.assertTrue(not proto.HasField(composite_field_name))

      # Now set the scalar within the composite field.  Ensure that the setting
      # is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      # Assert that the has methods now return true.
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Now call the clear method...
      proto.ClearField(composite_field_name)

      # ...and ensure that the "has" bits are all back to False...
      composite_field = getattr(proto, composite_field_name)
      self.assertTrue(not composite_field.HasField(scalar_field_name))
      self.assertTrue(not proto.HasField(composite_field_name))
      # ...and ensure that the scalar field has returned to its default.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # Finally, ensure that modifications to the old composite field object
      # don't have any effect on the parent.
      #
      # (NOTE that when we clear the composite field in the parent, we actually
      # don't recursively clear down the tree.  Instead, we just disconnect the
      # cleared composite from the tree.)
      self.assertTrue(old_composite_field is not composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertTrue(not composite_field.HasField(scalar_field_name))
      self.assertTrue(not proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')

  def testReferencesToNestedMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23

  def testDisconnectingNestedMessageBeforeSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertTrue(nested is not proto.optional_nested_message)
    nested.bb = 23
    self.assertTrue(not proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testHasBitsWhenModifyingRepeatedFields(self):
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertTrue(not proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))

  def testHasBitsForManyLevelsOfNesting(self):
    # Test nesting many levels deep.
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertTrue(not recursive_proto.HasField('bb'))
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(not recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))

  def testSingularListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    # Access sub-message but don't set it yet.
    nested_message = proto.optional_nested_message
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string'], 'foo')],
        proto.ListFields())

    proto.optional_nested_message.bb = 123
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string'], 'foo'),
         (proto.DESCRIPTOR.fields_by_name['optional_nested_message'],
          nested_message)],
        proto.ListFields())

  def testRepeatedListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(11)
    proto.repeated_string.extend(['foo', 'bar'])
    proto.repeated_string.extend([])
    proto.repeated_string.append('baz')
    proto.repeated_string.extend(str(x) for x in xrange(2))
    proto.optional_int32 = 21
    proto.repeated_bool  # Access but don't set anything; should not be listed.
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 21),
         (proto.DESCRIPTOR.fields_by_name['repeated_int32'], [5, 11]),
         (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
         (proto.DESCRIPTOR.fields_by_name['repeated_string'],
          ['foo', 'bar', 'baz', '0', '1'])],
        proto.ListFields())

  def testSingularListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
    proto.Extensions[unittest_pb2.optional_int32_extension] = 5
    proto.Extensions[unittest_pb2.optional_string_extension] = 'foo'
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        proto.ListFields())

  def testRepeatedListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(5)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(11)
    proto.Extensions[unittest_pb2.repeated_string_extension].append('foo')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('bar')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('baz')
    proto.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testListFieldsAndExtensions(self):
    # Regular fields and extensions are expected back from ListFields()
    # interleaved in the order asserted below.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (proto.DESCRIPTOR.fields_by_name['my_float'], 1.0)],
        proto.ListFields())

  def testDefaultValues(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    self.assertEqual(0, proto.optional_int64)
    self.assertEqual(0, proto.optional_uint32)
    self.assertEqual(0, proto.optional_uint64)
    self.assertEqual(0, proto.optional_sint32)
    self.assertEqual(0, proto.optional_sint64)
    self.assertEqual(0, proto.optional_fixed32)
    self.assertEqual(0, proto.optional_fixed64)
    self.assertEqual(0, proto.optional_sfixed32)
    self.assertEqual(0, proto.optional_sfixed64)
    self.assertEqual(0.0, proto.optional_float)
    self.assertEqual(0.0, proto.optional_double)
    self.assertEqual(False, proto.optional_bool)
    self.assertEqual('', proto.optional_string)
    self.assertEqual('', proto.optional_bytes)

    self.assertEqual(41, proto.default_int32)
    self.assertEqual(42, proto.default_int64)
    self.assertEqual(43, proto.default_uint32)
    self.assertEqual(44, proto.default_uint64)
    self.assertEqual(-45, proto.default_sint32)
    self.assertEqual(46, proto.default_sint64)
    self.assertEqual(47, proto.default_fixed32)
    self.assertEqual(48, proto.default_fixed64)
    self.assertEqual(49, proto.default_sfixed32)
    self.assertEqual(-50, proto.default_sfixed64)
    self.assertEqual(51.5, proto.default_float)
    self.assertEqual(52e3, proto.default_double)
    self.assertEqual(True, proto.default_bool)
    self.assertEqual('hello', proto.default_string)
    self.assertEqual('world', proto.default_bytes)
    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
                     proto.default_import_enum)

    proto = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', proto.utf8_string)

  def testHasFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')

  def testClearFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')

  def testDisallowedAssignments(self):
    # It's illegal to assign values directly to repeated fields
    # or to nonrepeated composite fields.  Ensure that this fails.
    proto = unittest_pb2.TestAllTypes()
    # Repeated fields.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
    # Lists shouldn't work, either.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
    # Composite fields.
    self.assertRaises(AttributeError, setattr, proto,
                      'optional_nested_message', 23)
    # Assignment to a repeated nested message field without specifying
    # the index in the array of nested messages.
    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
                      'bb', 34)
    # Assignment to an attribute of a repeated field.
    self.assertRaises(AttributeError, setattr, proto.repeated_float,
                      'some_attribute', 34)
    # proto.nonexistent_field = 23 should fail as well.
    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)

  # TODO(robinson): Add type-safety check for enums.
504 def testSingleScalarTypeSafety(self): 505 proto = unittest_pb2.TestAllTypes() 506 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) 507 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') 508 self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) 509 self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) 510 511 def testSingleScalarBoundsChecking(self): 512 def TestMinAndMaxIntegers(field_name, expected_min, expected_max): 513 pb = unittest_pb2.TestAllTypes() 514 setattr(pb, field_name, expected_min) 515 setattr(pb, field_name, expected_max) 516 self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) 517 self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) 518 519 TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) 520 TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) 521 TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) 522 TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) 523 TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) 524 525 def testRepeatedScalarTypeSafety(self): 526 proto = unittest_pb2.TestAllTypes() 527 self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) 528 self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') 529 self.assertRaises(TypeError, proto.repeated_string, 10) 530 self.assertRaises(TypeError, proto.repeated_bytes, 10) 531 532 proto.repeated_int32.append(10) 533 proto.repeated_int32[0] = 23 534 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) 535 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') 536 537 def testSingleScalarGettersAndSetters(self): 538 proto = unittest_pb2.TestAllTypes() 539 self.assertEqual(0, proto.optional_int32) 540 proto.optional_int32 = 1 541 self.assertEqual(1, proto.optional_int32) 542 # TODO(robinson): Test all other scalar field types. 
543 544 def testSingleScalarClearField(self): 545 proto = unittest_pb2.TestAllTypes() 546 # Should be allowed to clear something that's not there (a no-op). 547 proto.ClearField('optional_int32') 548 proto.optional_int32 = 1 549 self.assertTrue(proto.HasField('optional_int32')) 550 proto.ClearField('optional_int32') 551 self.assertEqual(0, proto.optional_int32) 552 self.assertTrue(not proto.HasField('optional_int32')) 553 # TODO(robinson): Test all other scalar field types. 554 555 def testEnums(self): 556 proto = unittest_pb2.TestAllTypes() 557 self.assertEqual(1, proto.FOO) 558 self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) 559 self.assertEqual(2, proto.BAR) 560 self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) 561 self.assertEqual(3, proto.BAZ) 562 self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) 563 564 def testRepeatedScalars(self): 565 proto = unittest_pb2.TestAllTypes() 566 567 self.assertTrue(not proto.repeated_int32) 568 self.assertEqual(0, len(proto.repeated_int32)) 569 proto.repeated_int32.append(5) 570 proto.repeated_int32.append(10) 571 proto.repeated_int32.append(15) 572 self.assertTrue(proto.repeated_int32) 573 self.assertEqual(3, len(proto.repeated_int32)) 574 575 self.assertEqual([5, 10, 15], proto.repeated_int32) 576 577 # Test single retrieval. 578 self.assertEqual(5, proto.repeated_int32[0]) 579 self.assertEqual(15, proto.repeated_int32[-1]) 580 # Test out-of-bounds indices. 581 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 582 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 583 # Test incorrect types passed to __getitem__. 584 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 585 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 586 587 # Test single assignment. 588 proto.repeated_int32[1] = 20 589 self.assertEqual([5, 20, 15], proto.repeated_int32) 590 591 # Test insertion. 
592 proto.repeated_int32.insert(1, 25) 593 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 594 595 # Test slice retrieval. 596 proto.repeated_int32.append(30) 597 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 598 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 599 600 # Test slice assignment with an iterator 601 proto.repeated_int32[1:4] = (i for i in xrange(3)) 602 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 603 604 # Test slice assignment. 605 proto.repeated_int32[1:4] = [35, 40, 45] 606 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 607 608 # Test that we can use the field as an iterator. 609 result = [] 610 for i in proto.repeated_int32: 611 result.append(i) 612 self.assertEqual([5, 35, 40, 45, 30], result) 613 614 # Test single deletion. 615 del proto.repeated_int32[2] 616 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 617 618 # Test slice deletion. 619 del proto.repeated_int32[2:] 620 self.assertEqual([5, 35], proto.repeated_int32) 621 622 # Test clearing. 
623 proto.ClearField('repeated_int32') 624 self.assertTrue(not proto.repeated_int32) 625 self.assertEqual(0, len(proto.repeated_int32)) 626 627 def testRepeatedScalarsRemove(self): 628 proto = unittest_pb2.TestAllTypes() 629 630 self.assertTrue(not proto.repeated_int32) 631 self.assertEqual(0, len(proto.repeated_int32)) 632 proto.repeated_int32.append(5) 633 proto.repeated_int32.append(10) 634 proto.repeated_int32.append(5) 635 proto.repeated_int32.append(5) 636 637 self.assertEqual(4, len(proto.repeated_int32)) 638 proto.repeated_int32.remove(5) 639 self.assertEqual(3, len(proto.repeated_int32)) 640 self.assertEqual(10, proto.repeated_int32[0]) 641 self.assertEqual(5, proto.repeated_int32[1]) 642 self.assertEqual(5, proto.repeated_int32[2]) 643 644 proto.repeated_int32.remove(5) 645 self.assertEqual(2, len(proto.repeated_int32)) 646 self.assertEqual(10, proto.repeated_int32[0]) 647 self.assertEqual(5, proto.repeated_int32[1]) 648 649 proto.repeated_int32.remove(10) 650 self.assertEqual(1, len(proto.repeated_int32)) 651 self.assertEqual(5, proto.repeated_int32[0]) 652 653 # Remove a non-existent element. 654 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 655 656 def testRepeatedComposites(self): 657 proto = unittest_pb2.TestAllTypes() 658 self.assertTrue(not proto.repeated_nested_message) 659 self.assertEqual(0, len(proto.repeated_nested_message)) 660 m0 = proto.repeated_nested_message.add() 661 m1 = proto.repeated_nested_message.add() 662 self.assertTrue(proto.repeated_nested_message) 663 self.assertEqual(2, len(proto.repeated_nested_message)) 664 self.assertIs([m0, m1], proto.repeated_nested_message) 665 self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) 666 667 # Test out-of-bounds indices. 668 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 669 1234) 670 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 671 -1234) 672 673 # Test incorrect types passed to __getitem__. 
674 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 675 'foo') 676 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 677 None) 678 679 # Test slice retrieval. 680 m2 = proto.repeated_nested_message.add() 681 m3 = proto.repeated_nested_message.add() 682 m4 = proto.repeated_nested_message.add() 683 self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) 684 self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 685 686 # Test that we can use the field as an iterator. 687 result = [] 688 for i in proto.repeated_nested_message: 689 result.append(i) 690 self.assertIs([m0, m1, m2, m3, m4], result) 691 692 # Test single deletion. 693 del proto.repeated_nested_message[2] 694 self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) 695 696 # Test slice deletion. 697 del proto.repeated_nested_message[2:] 698 self.assertIs([m0, m1], proto.repeated_nested_message) 699 700 # Test clearing. 701 proto.ClearField('repeated_nested_message') 702 self.assertTrue(not proto.repeated_nested_message) 703 self.assertEqual(0, len(proto.repeated_nested_message)) 704 705 def testHandWrittenReflection(self): 706 # TODO(robinson): We probably need a better way to specify 707 # protocol types by hand. But then again, this isn't something 708 # we expect many people to do. Hmm. 
709 FieldDescriptor = descriptor.FieldDescriptor 710 foo_field_descriptor = FieldDescriptor( 711 name='foo_field', full_name='MyProto.foo_field', 712 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 713 cpp_type=FieldDescriptor.CPPTYPE_INT64, 714 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 715 containing_type=None, message_type=None, enum_type=None, 716 is_extension=False, extension_scope=None, 717 options=descriptor_pb2.FieldOptions()) 718 mydescriptor = descriptor.Descriptor( 719 name='MyProto', full_name='MyProto', filename='ignored', 720 containing_type=None, nested_types=[], enum_types=[], 721 fields=[foo_field_descriptor], extensions=[], 722 options=descriptor_pb2.MessageOptions()) 723 class MyProtoClass(message.Message): 724 DESCRIPTOR = mydescriptor 725 __metaclass__ = reflection.GeneratedProtocolMessageType 726 myproto_instance = MyProtoClass() 727 self.assertEqual(0, myproto_instance.foo_field) 728 self.assertTrue(not myproto_instance.HasField('foo_field')) 729 myproto_instance.foo_field = 23 730 self.assertEqual(23, myproto_instance.foo_field) 731 self.assertTrue(myproto_instance.HasField('foo_field')) 732 733 def testTopLevelExtensionsForOptionalScalar(self): 734 extendee_proto = unittest_pb2.TestAllExtensions() 735 extension = unittest_pb2.optional_int32_extension 736 self.assertTrue(not extendee_proto.HasExtension(extension)) 737 self.assertEqual(0, extendee_proto.Extensions[extension]) 738 # As with normal scalar fields, just doing a read doesn't actually set the 739 # "has" bit. 740 self.assertTrue(not extendee_proto.HasExtension(extension)) 741 # Actually set the thing. 742 extendee_proto.Extensions[extension] = 23 743 self.assertEqual(23, extendee_proto.Extensions[extension]) 744 self.assertTrue(extendee_proto.HasExtension(extension)) 745 # Ensure that clearing works as well. 
746 extendee_proto.ClearExtension(extension) 747 self.assertEqual(0, extendee_proto.Extensions[extension]) 748 self.assertTrue(not extendee_proto.HasExtension(extension)) 749 750 def testTopLevelExtensionsForRepeatedScalar(self): 751 extendee_proto = unittest_pb2.TestAllExtensions() 752 extension = unittest_pb2.repeated_string_extension 753 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 754 extendee_proto.Extensions[extension].append('foo') 755 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 756 string_list = extendee_proto.Extensions[extension] 757 extendee_proto.ClearExtension(extension) 758 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 759 self.assertTrue(string_list is not extendee_proto.Extensions[extension]) 760 # Shouldn't be allowed to do Extensions[extension] = 'a' 761 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 762 extension, 'a') 763 764 def testTopLevelExtensionsForOptionalMessage(self): 765 extendee_proto = unittest_pb2.TestAllExtensions() 766 extension = unittest_pb2.optional_foreign_message_extension 767 self.assertTrue(not extendee_proto.HasExtension(extension)) 768 self.assertEqual(0, extendee_proto.Extensions[extension].c) 769 # As with normal (non-extension) fields, merely reading from the 770 # thing shouldn't set the "has" bit. 771 self.assertTrue(not extendee_proto.HasExtension(extension)) 772 extendee_proto.Extensions[extension].c = 23 773 self.assertEqual(23, extendee_proto.Extensions[extension].c) 774 self.assertTrue(extendee_proto.HasExtension(extension)) 775 # Save a reference here. 776 foreign_message = extendee_proto.Extensions[extension] 777 extendee_proto.ClearExtension(extension) 778 self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) 779 # Setting a field on foreign_message now shouldn't set 780 # any "has" bits on extendee_proto. 
781 foreign_message.c = 42 782 self.assertEqual(42, foreign_message.c) 783 self.assertTrue(foreign_message.HasField('c')) 784 self.assertTrue(not extendee_proto.HasExtension(extension)) 785 # Shouldn't be allowed to do Extensions[extension] = 'a' 786 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 787 extension, 'a') 788 789 def testTopLevelExtensionsForRepeatedMessage(self): 790 extendee_proto = unittest_pb2.TestAllExtensions() 791 extension = unittest_pb2.repeatedgroup_extension 792 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 793 group = extendee_proto.Extensions[extension].add() 794 group.a = 23 795 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 796 group.a = 42 797 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 798 group_list = extendee_proto.Extensions[extension] 799 extendee_proto.ClearExtension(extension) 800 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 801 self.assertTrue(group_list is not extendee_proto.Extensions[extension]) 802 # Shouldn't be allowed to do Extensions[extension] = 'a' 803 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 804 extension, 'a') 805 806 def testNestedExtensions(self): 807 extendee_proto = unittest_pb2.TestAllExtensions() 808 extension = unittest_pb2.TestRequired.single 809 810 # We just test the non-repeated case. 
811 self.assertTrue(not extendee_proto.HasExtension(extension)) 812 required = extendee_proto.Extensions[extension] 813 self.assertEqual(0, required.a) 814 self.assertTrue(not extendee_proto.HasExtension(extension)) 815 required.a = 23 816 self.assertEqual(23, extendee_proto.Extensions[extension].a) 817 self.assertTrue(extendee_proto.HasExtension(extension)) 818 extendee_proto.ClearExtension(extension) 819 self.assertTrue(required is not extendee_proto.Extensions[extension]) 820 self.assertTrue(not extendee_proto.HasExtension(extension)) 821 822 # If message A directly contains message B, and 823 # a.HasField('b') is currently False, then mutating any 824 # extension in B should change a.HasField('b') to True 825 # (and so on up the object tree). 826 def testHasBitsForAncestorsOfExtendedMessage(self): 827 # Optional scalar extension. 828 toplevel = more_extensions_pb2.TopLevelMessage() 829 self.assertTrue(not toplevel.HasField('submessage')) 830 self.assertEqual(0, toplevel.submessage.Extensions[ 831 more_extensions_pb2.optional_int_extension]) 832 self.assertTrue(not toplevel.HasField('submessage')) 833 toplevel.submessage.Extensions[ 834 more_extensions_pb2.optional_int_extension] = 23 835 self.assertEqual(23, toplevel.submessage.Extensions[ 836 more_extensions_pb2.optional_int_extension]) 837 self.assertTrue(toplevel.HasField('submessage')) 838 839 # Repeated scalar extension. 840 toplevel = more_extensions_pb2.TopLevelMessage() 841 self.assertTrue(not toplevel.HasField('submessage')) 842 self.assertEqual([], toplevel.submessage.Extensions[ 843 more_extensions_pb2.repeated_int_extension]) 844 self.assertTrue(not toplevel.HasField('submessage')) 845 toplevel.submessage.Extensions[ 846 more_extensions_pb2.repeated_int_extension].append(23) 847 self.assertEqual([23], toplevel.submessage.Extensions[ 848 more_extensions_pb2.repeated_int_extension]) 849 self.assertTrue(toplevel.HasField('submessage')) 850 851 # Optional message extension. 
852 toplevel = more_extensions_pb2.TopLevelMessage() 853 self.assertTrue(not toplevel.HasField('submessage')) 854 self.assertEqual(0, toplevel.submessage.Extensions[ 855 more_extensions_pb2.optional_message_extension].foreign_message_int) 856 self.assertTrue(not toplevel.HasField('submessage')) 857 toplevel.submessage.Extensions[ 858 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 859 self.assertEqual(23, toplevel.submessage.Extensions[ 860 more_extensions_pb2.optional_message_extension].foreign_message_int) 861 self.assertTrue(toplevel.HasField('submessage')) 862 863 # Repeated message extension. 864 toplevel = more_extensions_pb2.TopLevelMessage() 865 self.assertTrue(not toplevel.HasField('submessage')) 866 self.assertEqual(0, len(toplevel.submessage.Extensions[ 867 more_extensions_pb2.repeated_message_extension])) 868 self.assertTrue(not toplevel.HasField('submessage')) 869 foreign = toplevel.submessage.Extensions[ 870 more_extensions_pb2.repeated_message_extension].add() 871 self.assertTrue(foreign is toplevel.submessage.Extensions[ 872 more_extensions_pb2.repeated_message_extension][0]) 873 self.assertTrue(toplevel.HasField('submessage')) 874 875 def testDisconnectionAfterClearingEmptyMessage(self): 876 toplevel = more_extensions_pb2.TopLevelMessage() 877 extendee_proto = toplevel.submessage 878 extension = more_extensions_pb2.optional_message_extension 879 extension_proto = extendee_proto.Extensions[extension] 880 extendee_proto.ClearExtension(extension) 881 extension_proto.foreign_message_int = 23 882 883 self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) 884 885 def testExtensionFailureModes(self): 886 extendee_proto = unittest_pb2.TestAllExtensions() 887 888 # Try non-extension-handle arguments to HasExtension, 889 # ClearExtension(), and Extensions[]... 
890 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 891 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 892 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 893 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 894 895 # Try something that *is* an extension handle, just not for 896 # this message... 897 unknown_handle = more_extensions_pb2.optional_int_extension 898 self.assertRaises(KeyError, extendee_proto.HasExtension, 899 unknown_handle) 900 self.assertRaises(KeyError, extendee_proto.ClearExtension, 901 unknown_handle) 902 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 903 unknown_handle) 904 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 905 unknown_handle, 5) 906 907 # Try call HasExtension() with a valid handle, but for a 908 # *repeated* field. (Just as with non-extension repeated 909 # fields, Has*() isn't supported for extension repeated fields). 910 self.assertRaises(KeyError, extendee_proto.HasExtension, 911 unittest_pb2.repeated_string_extension) 912 913 def testStaticParseFrom(self): 914 proto1 = unittest_pb2.TestAllTypes() 915 test_util.SetAllFields(proto1) 916 917 string1 = proto1.SerializeToString() 918 proto2 = unittest_pb2.TestAllTypes.FromString(string1) 919 920 # Messages should be equal. 921 self.assertEqual(proto2, proto1) 922 923 def testMergeFromSingularField(self): 924 # Test merge with just a singular field. 925 proto1 = unittest_pb2.TestAllTypes() 926 proto1.optional_int32 = 1 927 928 proto2 = unittest_pb2.TestAllTypes() 929 # This shouldn't get overwritten. 930 proto2.optional_string = 'value' 931 932 proto2.MergeFrom(proto1) 933 self.assertEqual(1, proto2.optional_int32) 934 self.assertEqual('value', proto2.optional_string) 935 936 def testMergeFromRepeatedField(self): 937 # Test merge with just a repeated field. 
938 proto1 = unittest_pb2.TestAllTypes() 939 proto1.repeated_int32.append(1) 940 proto1.repeated_int32.append(2) 941 942 proto2 = unittest_pb2.TestAllTypes() 943 proto2.repeated_int32.append(0) 944 proto2.MergeFrom(proto1) 945 946 self.assertEqual(0, proto2.repeated_int32[0]) 947 self.assertEqual(1, proto2.repeated_int32[1]) 948 self.assertEqual(2, proto2.repeated_int32[2]) 949 950 def testMergeFromOptionalGroup(self): 951 # Test merge with an optional group. 952 proto1 = unittest_pb2.TestAllTypes() 953 proto1.optionalgroup.a = 12 954 proto2 = unittest_pb2.TestAllTypes() 955 proto2.MergeFrom(proto1) 956 self.assertEqual(12, proto2.optionalgroup.a) 957 958 def testMergeFromRepeatedNestedMessage(self): 959 # Test merge with a repeated nested message. 960 proto1 = unittest_pb2.TestAllTypes() 961 m = proto1.repeated_nested_message.add() 962 m.bb = 123 963 m = proto1.repeated_nested_message.add() 964 m.bb = 321 965 966 proto2 = unittest_pb2.TestAllTypes() 967 m = proto2.repeated_nested_message.add() 968 m.bb = 999 969 proto2.MergeFrom(proto1) 970 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 971 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 972 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 973 974 def testMergeFromAllFields(self): 975 # With all fields set. 976 proto1 = unittest_pb2.TestAllTypes() 977 test_util.SetAllFields(proto1) 978 proto2 = unittest_pb2.TestAllTypes() 979 proto2.MergeFrom(proto1) 980 981 # Messages should be equal. 982 self.assertEqual(proto2, proto1) 983 984 # Serialized string should be equal too. 
985 string1 = proto1.SerializeToString() 986 string2 = proto2.SerializeToString() 987 self.assertEqual(string1, string2) 988 989 def testMergeFromExtensionsSingular(self): 990 proto1 = unittest_pb2.TestAllExtensions() 991 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 992 993 proto2 = unittest_pb2.TestAllExtensions() 994 proto2.MergeFrom(proto1) 995 self.assertEqual( 996 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 997 998 def testMergeFromExtensionsRepeated(self): 999 proto1 = unittest_pb2.TestAllExtensions() 1000 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1001 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1002 1003 proto2 = unittest_pb2.TestAllExtensions() 1004 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1005 proto2.MergeFrom(proto1) 1006 self.assertEqual( 1007 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1008 self.assertEqual( 1009 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1010 self.assertEqual( 1011 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1012 self.assertEqual( 1013 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1014 1015 def testMergeFromExtensionsNestedMessage(self): 1016 proto1 = unittest_pb2.TestAllExtensions() 1017 ext1 = proto1.Extensions[ 1018 unittest_pb2.repeated_nested_message_extension] 1019 m = ext1.add() 1020 m.bb = 222 1021 m = ext1.add() 1022 m.bb = 333 1023 1024 proto2 = unittest_pb2.TestAllExtensions() 1025 ext2 = proto2.Extensions[ 1026 unittest_pb2.repeated_nested_message_extension] 1027 m = ext2.add() 1028 m.bb = 111 1029 1030 proto2.MergeFrom(proto1) 1031 ext2 = proto2.Extensions[ 1032 unittest_pb2.repeated_nested_message_extension] 1033 self.assertEqual(3, len(ext2)) 1034 self.assertEqual(111, ext2[0].bb) 1035 self.assertEqual(222, ext2[1].bb) 1036 self.assertEqual(333, ext2[2].bb) 1037 1038 def testCopyFromSingularField(self): 1039 # Test copy with just a 
singular field. 1040 proto1 = unittest_pb2.TestAllTypes() 1041 proto1.optional_int32 = 1 1042 proto1.optional_string = 'important-text' 1043 1044 proto2 = unittest_pb2.TestAllTypes() 1045 proto2.optional_string = 'value' 1046 1047 proto2.CopyFrom(proto1) 1048 self.assertEqual(1, proto2.optional_int32) 1049 self.assertEqual('important-text', proto2.optional_string) 1050 1051 def testCopyFromRepeatedField(self): 1052 # Test copy with a repeated field. 1053 proto1 = unittest_pb2.TestAllTypes() 1054 proto1.repeated_int32.append(1) 1055 proto1.repeated_int32.append(2) 1056 1057 proto2 = unittest_pb2.TestAllTypes() 1058 proto2.repeated_int32.append(0) 1059 proto2.CopyFrom(proto1) 1060 1061 self.assertEqual(1, proto2.repeated_int32[0]) 1062 self.assertEqual(2, proto2.repeated_int32[1]) 1063 1064 def testCopyFromAllFields(self): 1065 # With all fields set. 1066 proto1 = unittest_pb2.TestAllTypes() 1067 test_util.SetAllFields(proto1) 1068 proto2 = unittest_pb2.TestAllTypes() 1069 proto2.CopyFrom(proto1) 1070 1071 # Messages should be equal. 1072 self.assertEqual(proto2, proto1) 1073 1074 # Serialized string should be equal too. 1075 string1 = proto1.SerializeToString() 1076 string2 = proto2.SerializeToString() 1077 self.assertEqual(string1, string2) 1078 1079 def testCopyFromSelf(self): 1080 proto1 = unittest_pb2.TestAllTypes() 1081 proto1.repeated_int32.append(1) 1082 proto1.optional_int32 = 2 1083 proto1.optional_string = 'important-text' 1084 1085 proto1.CopyFrom(proto1) 1086 self.assertEqual(1, proto1.repeated_int32[0]) 1087 self.assertEqual(2, proto1.optional_int32) 1088 self.assertEqual('important-text', proto1.optional_string) 1089 1090 def testClear(self): 1091 proto = unittest_pb2.TestAllTypes() 1092 test_util.SetAllFields(proto) 1093 # Clear the message. 1094 proto.Clear() 1095 self.assertEquals(proto.ByteSize(), 0) 1096 empty_proto = unittest_pb2.TestAllTypes() 1097 self.assertEquals(proto, empty_proto) 1098 1099 # Test if extensions which were set are cleared. 
1100 proto = unittest_pb2.TestAllExtensions() 1101 test_util.SetAllExtensions(proto) 1102 # Clear the message. 1103 proto.Clear() 1104 self.assertEquals(proto.ByteSize(), 0) 1105 empty_proto = unittest_pb2.TestAllExtensions() 1106 self.assertEquals(proto, empty_proto) 1107 1108 def assertInitialized(self, proto): 1109 self.assertTrue(proto.IsInitialized()) 1110 # Neither method should raise an exception. 1111 proto.SerializeToString() 1112 proto.SerializePartialToString() 1113 1114 def assertNotInitialized(self, proto): 1115 self.assertFalse(proto.IsInitialized()) 1116 self.assertRaises(message.EncodeError, proto.SerializeToString) 1117 # "Partial" serialization doesn't care if message is uninitialized. 1118 proto.SerializePartialToString() 1119 1120 def testIsInitialized(self): 1121 # Trivial cases - all optional fields and extensions. 1122 proto = unittest_pb2.TestAllTypes() 1123 self.assertInitialized(proto) 1124 proto = unittest_pb2.TestAllExtensions() 1125 self.assertInitialized(proto) 1126 1127 # The case of uninitialized required fields. 1128 proto = unittest_pb2.TestRequired() 1129 self.assertNotInitialized(proto) 1130 proto.a = proto.b = proto.c = 2 1131 self.assertInitialized(proto) 1132 1133 # The case of uninitialized submessage. 1134 proto = unittest_pb2.TestRequiredForeign() 1135 self.assertInitialized(proto) 1136 proto.optional_message.a = 1 1137 self.assertNotInitialized(proto) 1138 proto.optional_message.b = 0 1139 proto.optional_message.c = 0 1140 self.assertInitialized(proto) 1141 1142 # Uninitialized repeated submessage. 1143 message1 = proto.repeated_message.add() 1144 self.assertNotInitialized(proto) 1145 message1.a = message1.b = message1.c = 0 1146 self.assertInitialized(proto) 1147 1148 # Uninitialized repeated group in an extension. 
1149 proto = unittest_pb2.TestAllExtensions() 1150 extension = unittest_pb2.TestRequired.multi 1151 message1 = proto.Extensions[extension].add() 1152 message2 = proto.Extensions[extension].add() 1153 self.assertNotInitialized(proto) 1154 message1.a = 1 1155 message1.b = 1 1156 message1.c = 1 1157 self.assertNotInitialized(proto) 1158 message2.a = 2 1159 message2.b = 2 1160 message2.c = 2 1161 self.assertInitialized(proto) 1162 1163 # Uninitialized nonrepeated message in an extension. 1164 proto = unittest_pb2.TestAllExtensions() 1165 extension = unittest_pb2.TestRequired.single 1166 proto.Extensions[extension].a = 1 1167 self.assertNotInitialized(proto) 1168 proto.Extensions[extension].b = 2 1169 proto.Extensions[extension].c = 3 1170 self.assertInitialized(proto) 1171 1172 # Try passing an errors list. 1173 errors = [] 1174 proto = unittest_pb2.TestRequired() 1175 self.assertFalse(proto.IsInitialized(errors)) 1176 self.assertEqual(errors, ['a', 'b', 'c']) 1177 1178 def testStringUTF8Encoding(self): 1179 proto = unittest_pb2.TestAllTypes() 1180 1181 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 1182 self.assertRaises(TypeError, 1183 setattr, proto, 'optional_bytes', u'unicode object') 1184 1185 # Check that the default value is of python's 'unicode' type. 1186 self.assertEqual(type(proto.optional_string), unicode) 1187 1188 proto.optional_string = unicode('Testing') 1189 self.assertEqual(proto.optional_string, str('Testing')) 1190 1191 # Assign a value of type 'str' which can be encoded in UTF-8. 1192 proto.optional_string = str('Testing') 1193 self.assertEqual(proto.optional_string, unicode('Testing')) 1194 1195 # Values of type 'str' are also accepted as long as they can be encoded in 1196 # UTF-8. 1197 self.assertEqual(type(proto.optional_string), str) 1198 1199 # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. 
1200 self.assertRaises(ValueError, 1201 setattr, proto, 'optional_string', str('a\x80a')) 1202 # Assign a 'str' object which contains a UTF-8 encoded string. 1203 self.assertRaises(ValueError, 1204 setattr, proto, 'optional_string', 'Тест') 1205 # No exception thrown. 1206 proto.optional_string = 'abc' 1207 1208 def testStringUTF8Serialization(self): 1209 proto = unittest_mset_pb2.TestMessageSet() 1210 extension_message = unittest_mset_pb2.TestMessageSetExtension2 1211 extension = extension_message.message_set_extension 1212 1213 test_utf8 = u'Тест' 1214 test_utf8_bytes = test_utf8.encode('utf-8') 1215 1216 # 'Test' in another language, using UTF-8 charset. 1217 proto.Extensions[extension].str = test_utf8 1218 1219 # Serialize using the MessageSet wire format (this is specified in the 1220 # .proto file). 1221 serialized = proto.SerializeToString() 1222 1223 # Check byte size. 1224 self.assertEqual(proto.ByteSize(), len(serialized)) 1225 1226 raw = unittest_mset_pb2.RawMessageSet() 1227 raw.MergeFromString(serialized) 1228 1229 message2 = unittest_mset_pb2.TestMessageSetExtension2() 1230 1231 self.assertEqual(1, len(raw.item)) 1232 # Check that the type_id is the same as the tag ID in the .proto file. 1233 self.assertEqual(raw.item[0].type_id, 1547769) 1234 1235 # Check the actually bytes on the wire. 1236 self.assertTrue( 1237 raw.item[0].message.endswith(test_utf8_bytes)) 1238 message2.MergeFromString(raw.item[0].message) 1239 1240 self.assertEqual(type(message2.str), unicode) 1241 self.assertEqual(message2.str, test_utf8) 1242 1243 # How about if the bytes on the wire aren't a valid UTF-8 encoded string. 
1244 bytes = raw.item[0].message.replace( 1245 test_utf8_bytes, len(test_utf8_bytes) * '\xff') 1246 self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) 1247 1248 def testEmptyNestedMessage(self): 1249 proto = unittest_pb2.TestAllTypes() 1250 proto.optional_nested_message.MergeFrom( 1251 unittest_pb2.TestAllTypes.NestedMessage()) 1252 self.assertTrue(proto.HasField('optional_nested_message')) 1253 1254 proto = unittest_pb2.TestAllTypes() 1255 proto.optional_nested_message.CopyFrom( 1256 unittest_pb2.TestAllTypes.NestedMessage()) 1257 self.assertTrue(proto.HasField('optional_nested_message')) 1258 1259 proto = unittest_pb2.TestAllTypes() 1260 proto.optional_nested_message.MergeFromString('') 1261 self.assertTrue(proto.HasField('optional_nested_message')) 1262 1263 proto = unittest_pb2.TestAllTypes() 1264 proto.optional_nested_message.ParseFromString('') 1265 self.assertTrue(proto.HasField('optional_nested_message')) 1266 1267 serialized = proto.SerializeToString() 1268 proto2 = unittest_pb2.TestAllTypes() 1269 proto2.MergeFromString(serialized) 1270 self.assertTrue(proto2.HasField('optional_nested_message')) 1271 1272 def testSetInParent(self): 1273 proto = unittest_pb2.TestAllTypes() 1274 self.assertFalse(proto.HasField('optionalgroup')) 1275 proto.optionalgroup.SetInParent() 1276 self.assertTrue(proto.HasField('optionalgroup')) 1277 1278 1279# Since we had so many tests for protocol buffer equality, we broke these out 1280# into separate TestCase classes. 


class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality tests using two freshly constructed, empty TestAllTypes."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testSelfEquality(self):
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)


class FullProtosEqualityTest(unittest.TestCase):
  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNoneNotEqual(self):
    # Check both operand orders.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # TODO(robinson): Replace next two lines with method
    # to set the "has" bit without changing the value,
    # if/when such a method exists.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)


class ExtensionEqualityTest(unittest.TestCase):
  """Equality tests for messages that carry extensions."""

  def testExtensionEquality(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)


class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality test for mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    first_proto = unittest_pb2.TestMutualRecursionA()
    second_proto = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(first_proto, second_proto)
    first_proto.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first_proto, second_proto)
    second_proto.bb.a.bb.optional_int32 = 23
    self.assertEqual(first_proto, second_proto)


class ByteSizeTest(unittest.TestCase):

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the current ByteSize() of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for the tag info
      # for tag 1.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value that fits in k varint bytes; this
    # covers k = 1..8 (shifts 7, 14, ..., 56).
    for num_bytes, shift in enumerate(range(7, 63, 7), 1):
      Test((1 << shift) - 1, num_bytes)
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    # self.proto is replaced with a fresh message after each case so that the
    # sizes below are measured with exactly one field set.
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    # Only the 1-byte value and its 2-byte tag remain.
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())
    child = self.proto.optional_foreign_message
    self.proto.ClearField('optional_foreign_message')
    child.c = 128
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
1638 extension = more_extensions_pb2.optional_message_extension 1639 child = self.extended_proto.Extensions[extension] 1640 self.assertEqual(0, self.extended_proto.ByteSize()) 1641 child.foreign_message_int = 1 1642 self.assertEqual(4, self.extended_proto.ByteSize()) 1643 child.foreign_message_int = 128 1644 self.assertEqual(5, self.extended_proto.ByteSize()) 1645 self.extended_proto.ClearExtension(extension) 1646 self.assertEqual(0, self.extended_proto.ByteSize()) 1647 1648 def testCacheInvalidationForRepeatedMessage(self): 1649 # Test non-extension. 1650 child0 = self.proto.repeated_foreign_message.add() 1651 self.assertEqual(3, self.proto.ByteSize()) 1652 self.proto.repeated_foreign_message.add() 1653 self.assertEqual(6, self.proto.ByteSize()) 1654 child0.c = 1 1655 self.assertEqual(8, self.proto.ByteSize()) 1656 self.proto.ClearField('repeated_foreign_message') 1657 self.assertEqual(0, self.proto.ByteSize()) 1658 1659 # Test within extension. 1660 extension = more_extensions_pb2.repeated_message_extension 1661 child_list = self.extended_proto.Extensions[extension] 1662 child0 = child_list.add() 1663 self.assertEqual(2, self.extended_proto.ByteSize()) 1664 child_list.add() 1665 self.assertEqual(4, self.extended_proto.ByteSize()) 1666 child0.foreign_message_int = 1 1667 self.assertEqual(6, self.extended_proto.ByteSize()) 1668 child0.ClearField('foreign_message_int') 1669 self.assertEqual(4, self.extended_proto.ByteSize()) 1670 self.extended_proto.ClearExtension(extension) 1671 self.assertEqual(0, self.extended_proto.ByteSize()) 1672 1673 def testPackedRepeatedScalars(self): 1674 self.assertEqual(0, self.packed_proto.ByteSize()) 1675 1676 self.packed_proto.packed_int32.append(10) # 1 byte. 1677 self.packed_proto.packed_int32.append(128) # 2 bytes. 1678 # The tag is 2 bytes (the field number is 90), and the varint 1679 # storing the length is 1 byte. 
1680 int_size = 1 + 2 + 3 1681 self.assertEqual(int_size, self.packed_proto.ByteSize()) 1682 1683 self.packed_proto.packed_double.append(4.2) # 8 bytes 1684 self.packed_proto.packed_double.append(3.25) # 8 bytes 1685 # 2 more tag bytes, 1 more length byte. 1686 double_size = 8 + 8 + 3 1687 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 1688 1689 self.packed_proto.ClearField('packed_int32') 1690 self.assertEqual(double_size, self.packed_proto.ByteSize()) 1691 1692 def testPackedExtensions(self): 1693 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 1694 extension = self.packed_extended_proto.Extensions[ 1695 unittest_pb2.packed_fixed32_extension] 1696 extension.extend([1, 2, 3, 4]) # 16 bytes 1697 # Tag is 3 bytes. 1698 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 1699 1700 1701# TODO(robinson): We need cross-language serialization consistency tests. 1702# Issues to be sure to cover include: 1703# * Handling of unrecognized tags ("uninterpreted_bytes"). 1704# * Handling of MessageSets. 1705# * Consistent ordering of tags in the wire format, 1706# including ordering between extensions and non-extension 1707# fields. 1708# * Consistent serialization of negative numbers, especially 1709# negative int32s. 1710# * Handling of empty submessages (with and without "has" 1711# bits set). 

class SerializationTest(unittest.TestCase):
  """Tests for serializing messages to the wire format and parsing them back."""

  def testSerializeEmtpyMessage(self):
    # An empty message must round-trip, and ByteSize() must agree with the
    # length of the serialized string.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    # Same round-trip check with every field populated by test_util.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    # Round-trip check for a message whose extensions are all populated.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    # Negative values in each signed integer field type must survive a
    # SerializeToString() / FromString() round trip.
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    # Parses every prefix of a fully populated message, both as TestAllTypes
    # and as TestEmptyMessage (where every field is unknown).  The two parses
    # must agree on whether the prefix is valid.
    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()

    for truncation_point in xrange(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    # Walk the raw bytes and check each (field number, wire type, value) in
    # turn.
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    # Merging serialized data into a message that already has values set.
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    second_proto.MergeFromString(serialized)
    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    # Serializes a TestMessageSet with two extensions and checks the result
    # both through RawMessageSet and through TestMessageSet itself.
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    raw.MergeFromString(serialized)
    self.assertEqual(2, len(raw.item))

    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.MergeFromString(raw.item[0].message)
    self.assertEqual(123, message1.i)

    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.MergeFromString(raw.item[1].message)
    self.assertEqual('foo', message2.str)

    # Deserialize using the MessageSet wire format.
    proto2 = unittest_mset_pb2.TestMessageSet()
    proto2.MergeFromString(serialized)
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item.
    item = raw.item.add()
    item.type_id = 1545008
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 1545009
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 1545010
    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.str = 'foo'
    item.message = message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = unittest_mset_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    # Check that the message parsed well.
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    # Parsing into TestEmptyMessage makes every field in the input unknown.
    # The test passes as long as MergeFromString() does not raise.
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    proto2.MergeFromString(serialized)

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    proto2.MergeFromString(serialized)

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj raises exc_class with the expected message.

    Args:
      exc_class: Exception type that callable_obj is expected to raise.
      callable_obj: Zero-argument callable to invoke.
      exception: Expected result of str() applied to the raised exception.

    Raises:
      self.failureException: If callable_obj returns without raising
        exc_class, or if the message of the raised exception differs.
    """
    import sys  # Local import; only needed to fetch the active exception.

    try:
      callable_obj()
    except exc_class:
      # Fetch the exception via sys.exc_info() rather than the
      # "except exc_class, ex" form, which is a syntax error on Python 3.
      ex = sys.exc_info()[1]
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
      return
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    # SerializeToString() must raise EncodeError naming the missing required
    # fields, while SerializePartialToString() must not raise at any stage.
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message is missing required fields: a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    # With every required field set, both serialized forms must parse back to
    # the same values.
    proto2 = unittest_pb2.TestRequired()
    proto2.MergeFromString(serialized)
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    proto2.ParseFromString(partial)
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    # Missing required fields inside submessages must be reported with their
    # full path, including the index for repeated submessages.
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    # Round-trip check for packed fields.  The value returned by
    # MergeFromString() is compared with the parsed message's ByteSize().
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    # Round-trip check for packed extensions.
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    # Merged packed values are appended after the existing ones, and fields
    # absent from the input are left untouched.
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    second_proto.MergeFromString(serialized)
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    # Checks the exact layout of packed fields: one length-delimited tag, a
    # length read as a varint, then the elements back to back.
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    # Data serialized from TestUnpackedTypes must parse into TestPackedTypes
    # and compare equal to a TestPackedTypes populated directly by test_util.
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    packed.MergeFromString(unpacked.SerializeToString())
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)
2083 2084 def testParseUnpackedFromPacked(self): 2085 packed = unittest_pb2.TestPackedTypes() 2086 test_util.SetAllPackedFields(packed) 2087 unpacked = unittest_pb2.TestUnpackedTypes() 2088 unpacked.MergeFromString(packed.SerializeToString()) 2089 expected = unittest_pb2.TestUnpackedTypes() 2090 test_util.SetAllUnpackedFields(expected) 2091 self.assertEqual(expected, unpacked) 2092 2093 def testFieldNumbers(self): 2094 proto = unittest_pb2.TestAllTypes() 2095 self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) 2096 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) 2097 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) 2098 self.assertEqual( 2099 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) 2100 self.assertEqual( 2101 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) 2102 self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) 2103 self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) 2104 self.assertEqual( 2105 unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) 2106 self.assertEqual( 2107 unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) 2108 2109 def testExtensionFieldNumbers(self): 2110 self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) 2111 self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) 2112 self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) 2113 self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) 2114 self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) 2115 self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) 2116 self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) 2117 self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) 2118 self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) 2119 self.assertEqual( 2120 
unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) 2121 self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) 2122 self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2123 21) 2124 self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) 2125 self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) 2126 self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) 2127 self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) 2128 self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) 2129 self.assertEqual( 2130 unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) 2131 self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) 2132 self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2133 51) 2134 2135 def testInitKwargs(self): 2136 proto = unittest_pb2.TestAllTypes( 2137 optional_int32=1, 2138 optional_string='foo', 2139 optional_bool=True, 2140 optional_bytes='bar', 2141 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 2142 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 2143 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 2144 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 2145 repeated_int32=[1, 2, 3]) 2146 self.assertTrue(proto.IsInitialized()) 2147 self.assertTrue(proto.HasField('optional_int32')) 2148 self.assertTrue(proto.HasField('optional_string')) 2149 self.assertTrue(proto.HasField('optional_bool')) 2150 self.assertTrue(proto.HasField('optional_bytes')) 2151 self.assertTrue(proto.HasField('optional_nested_message')) 2152 self.assertTrue(proto.HasField('optional_foreign_message')) 2153 self.assertTrue(proto.HasField('optional_nested_enum')) 2154 self.assertTrue(proto.HasField('optional_foreign_enum')) 2155 self.assertEqual(1, proto.optional_int32) 2156 self.assertEqual('foo', proto.optional_string) 2157 self.assertEqual(True, proto.optional_bool) 2158 
self.assertEqual('bar', proto.optional_bytes) 2159 self.assertEqual(1, proto.optional_nested_message.bb) 2160 self.assertEqual(1, proto.optional_foreign_message.c) 2161 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 2162 proto.optional_nested_enum) 2163 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 2164 self.assertEqual([1, 2, 3], proto.repeated_int32) 2165 2166 def testInitArgsUnknownFieldName(self): 2167 def InitalizeEmptyMessageWithExtraKeywordArg(): 2168 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 2169 self._CheckRaises(ValueError, 2170 InitalizeEmptyMessageWithExtraKeywordArg, 2171 'Protocol message has no "unknown" field.') 2172 2173 def testInitRequiredKwargs(self): 2174 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 2175 self.assertTrue(proto.IsInitialized()) 2176 self.assertTrue(proto.HasField('a')) 2177 self.assertTrue(proto.HasField('b')) 2178 self.assertTrue(proto.HasField('c')) 2179 self.assertTrue(not proto.HasField('dummy2')) 2180 self.assertEqual(1, proto.a) 2181 self.assertEqual(1, proto.b) 2182 self.assertEqual(1, proto.c) 2183 2184 def testInitRequiredForeignKwargs(self): 2185 proto = unittest_pb2.TestRequiredForeign( 2186 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 2187 self.assertTrue(proto.IsInitialized()) 2188 self.assertTrue(proto.HasField('optional_message')) 2189 self.assertTrue(proto.optional_message.IsInitialized()) 2190 self.assertTrue(proto.optional_message.HasField('a')) 2191 self.assertTrue(proto.optional_message.HasField('b')) 2192 self.assertTrue(proto.optional_message.HasField('c')) 2193 self.assertTrue(not proto.optional_message.HasField('dummy2')) 2194 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 2195 proto.optional_message) 2196 self.assertEqual(1, proto.optional_message.a) 2197 self.assertEqual(1, proto.optional_message.b) 2198 self.assertEqual(1, proto.optional_message.c) 2199 2200 def testInitRepeatedKwargs(self): 2201 proto = 
unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 2202 self.assertTrue(proto.IsInitialized()) 2203 self.assertEqual(1, proto.repeated_int32[0]) 2204 self.assertEqual(2, proto.repeated_int32[1]) 2205 self.assertEqual(3, proto.repeated_int32[2]) 2206 2207 2208class OptionsTest(unittest.TestCase): 2209 2210 def testMessageOptions(self): 2211 proto = unittest_mset_pb2.TestMessageSet() 2212 self.assertEqual(True, 2213 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2214 proto = unittest_pb2.TestAllTypes() 2215 self.assertEqual(False, 2216 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2217 2218 def testPackedOptions(self): 2219 proto = unittest_pb2.TestAllTypes() 2220 proto.optional_int32 = 1 2221 proto.optional_double = 3.0 2222 for field_descriptor, _ in proto.ListFields(): 2223 self.assertEqual(False, field_descriptor.GetOptions().packed) 2224 2225 proto = unittest_pb2.TestPackedTypes() 2226 proto.packed_int32.append(1) 2227 proto.packed_double.append(3.0) 2228 for field_descriptor, _ in proto.ListFields(): 2229 self.assertEqual(True, field_descriptor.GetOptions().packed) 2230 self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, 2231 field_descriptor.label) 2232 2233 2234 2235if __name__ == '__main__': 2236 unittest.main() 2237