1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31# Keep it Python2.5 compatible for GAE. 32# 33# Copyright 2007 Google Inc. All Rights Reserved. 34# 35# This code is meant to work on Python 2.4 and above only. 36# 37# TODO(robinson): Helpers for verbose, common checks like seeing if a 38# descriptor's cpp_type is CPPTYPE_MESSAGE. 
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

import sys
if sys.version_info[0] < 3:
  try:
    from cStringIO import StringIO as BytesIO
  except ImportError:
    from StringIO import StringIO as BytesIO
  import copy_reg as copyreg
else:
  from io import BytesIO
  import copyreg
import struct
import weakref

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor


def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary for a new protocol message class.

  Adds the descriptor's nested extensions as class attributes and installs a
  '__slots__' entry, both by mutating |dictionary| in place.

  Args:
    bases: Base classes of the class under construction.
    descriptor: Descriptor of the message type being built.
    dictionary: Class dictionary of the class under construction.

  Returns:
    |bases|, unmodified.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
  return bases


def InitMessage(descriptor, cls):
  """Populates a freshly created message class from its descriptor.

  Creates the per-class decoder and extension registries, attaches helpers to
  every field descriptor, installs the generated properties and methods, and
  registers a copyreg reduction function for the class.

  Args:
    descriptor: Descriptor of the message type |cls| represents.
    cls: The message class to populate.
  """
  # Class-level lookup tables; they are filled in by _AttachFieldHelpers() and
  # by RegisterExtension().
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}
  if (descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format):
    # Entries in _decoders_by_tag are (decoder, field descriptor or None)
    # pairs; the message-set item decoder has no associated field.
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(cls._extensions_by_number), None)

  # Attach stuff to each FieldDescriptor for quick lookup later on.
  for field in descriptor.fields:
    _AttachFieldHelpers(cls, field)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(descriptor, cls)
  # Reduce instances to (class, no constructor arguments, state), where the
  # state is whatever obj.__getstate__() returns.
  copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Maps a .proto field name to the name of its public Python property.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and in some cases set) the field.
    This is currently the field name itself, unmodified.
  """
  # The name is deliberately NOT escaped, even when it collides with a Python
  # keyword such as "yield".
  #
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this
  #   support.  nnorwitz suggested appending "_" whenever
  #   keyword.iskeyword(proto_field_name) is true.
  #
  # Kenton says: that is a BAD IDEA.  People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values.  If the
  # properties were renamed, every such user would also have to apply the same
  # transformation.  A field named "yield" can still be accessed just fine
  # through getattr/setattr.
  #
  # TODO(kenton): Remove this method entirely if/when everyone agrees with my
  #   position.
  property_name = proto_field_name
  return property_name
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that extension_handle is an extension of message's type.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The candidate extension FieldDescriptor.

  Raises:
    KeyError: If the handle is not a FieldDescriptor, is not an extension,
      has no containing_type, or extends a different message type.
  """
  handle = extension_handle

  if not isinstance(handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   handle)

  if not handle.is_extension:
    raise KeyError('"%s" is not an extension.' % handle.full_name)

  extended_type = handle.containing_type
  if not extended_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % handle.full_name)

  # Identity, not equality: the handle must extend exactly this descriptor.
  if extended_type is not message.DESCRIPTOR:
    details = (handle.full_name,
               extended_type.full_name,
               message.DESCRIPTOR.full_name)
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' % details)


def _AddSlots(message_descriptor, dictionary):
  """Stores the list of valid instance attribute names under '__slots__'.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      It is not consulted; the slot list is the same for every message type.
    dictionary: Class dictionary which receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
def _IsMessageSetExtension(field):
  """Tells whether field is an extension in message-set style.

  Args:
    field: A FieldDescriptor.

  Returns:
    A true value iff the field is an optional message-typed extension of a
    type that uses message_set_wire_format, and the extension is declared
    inside the very message type it carries.
  """
  extends_message_set = (
      field.is_extension and
      field.containing_type.has_options and
      field.containing_type.GetOptions().message_set_wire_format)
  return (extends_message_set and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding helpers to a field and registers its decoders on cls.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor, and adds
  one or two entries to cls._decoders_by_tag.

  Args:
    cls: The message class that owns (or is extended by) the field.
    field_descriptor: The FieldDescriptor to prepare.
  """
  number = field_descriptor.number
  field_type = field_descriptor.type
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    make_encoder = encoder.MessageSetItemEncoder
    make_sizer = encoder.MessageSetItemSizer
    helper_args = (number,)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_type]
    make_sizer = type_checkers.TYPE_TO_SIZER[field_type]
    helper_args = (number, is_repeated, is_packed)
  field_encoder = make_encoder(*helper_args)
  sizer = make_sizer(*helper_args)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the given wire type."""
    tag_bytes = encoder.TagBytes(number, wiretype)
    field_decoder = type_checkers.TYPE_TO_DECODER[field_type](
        number, is_repeated, is_packed,
        field_descriptor, field_descriptor._default_constructor)
    # The second element names the field only when it is a oneof member.
    if field_descriptor.containing_oneof is not None:
      oneof_member = field_descriptor
    else:
      oneof_member = None
    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_member)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)

  if is_repeated and wire_format.IsTypePackable(field_type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Copies the descriptor's nested extensions into the class dictionary.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is copied.
    dictionary: Class dictionary that receives one entry per extension name.
      It must not already contain any of those names.
  """
  for name, extension in descriptor.extensions_by_name.iteritems():
    assert name not in dictionary
    dictionary[name] = extension


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  For every nested enum type the class gets an EnumTypeWrapper stored under
  the enum's name, plus one integer attribute per enum value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument, |message| (the Message instance containing
    this field, or a weakref proxy of same), which returns a default value
    for the field.  The default value may refer back to |message| via
    message._listener_for_children.

  Raises:
    ValueError: If a repeated field declares a default other than [].
  """
  composite_type = _FieldDescriptor.CPPTYPE_MESSAGE

  if field.label != _FieldDescriptor.LABEL_REPEATED:
    if field.cpp_type != composite_type:
      def MakeScalarDefault(message):
        # TODO(protobuf-team): This may be broken since there may not be
        # default_value.  Combine with has_default_value somehow.
        return field.default_value
      return MakeScalarDefault

    # _concrete_class may not yet be initialized, so it is only looked up
    # when a default is actually requested.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  if field.has_default_value and field.default_value != []:
    raise ValueError('Repeated field default value not empty list: %s' % (
        field.default_value))

  if field.cpp_type == composite_type:
    # We can't look at _concrete_class yet since it might not have
    # been set.  (Depends on order in which we initialize the classes).
    def MakeRepeatedMessageDefault(message):
      return containers.RepeatedCompositeFieldContainer(
          message._listener_for_children, field.message_type)
    return MakeRepeatedMessageDefault

  type_checker = type_checkers.GetTypeChecker(field)
  def MakeRepeatedScalarDefault(message):
    return containers.RepeatedScalarFieldContainer(
        message._listener_for_children, type_checker)
  return MakeRepeatedScalarDefault
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts only keyword arguments, one per field to
  pre-populate.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class that receives the __init__ method.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for name, value in kwargs.iteritems():
      field = _GetFieldByName(message_descriptor, name)
      # NOTE(review): _GetFieldByName() as defined in this file raises
      # ValueError for unknown names rather than returning None, so this
      # check looks unreachable today -- confirm before relying on TypeError.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, name))

      is_repeated = (field.label == _FieldDescriptor.LABEL_REPEATED)
      is_composite = (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE)

      if not is_repeated and not is_composite:
        # Singular scalars go through the generated property setter.
        setattr(self, name, value)
        continue

      # Repeated and message-typed values are copied into a fresh container
      # or sub-message rather than stored directly.
      copy = field._default_constructor(self)
      if not is_repeated:
        copy.MergeFrom(value)
      elif is_composite:
        for element in value:
          copy.add().MergeFrom(element)
      else:
        copy.extend(value)
      self._fields[field] = copy

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message type has no field called field_name.
  """
  try:
    found = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message has no "%s" field.' % field_name)
  return found


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get an 'Extensions' property.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class being constructed.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field.  A
  <NAME>_FIELD_NUMBER class constant is added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the container produced by the field's
  _default_constructor, created on first access and then kept in the
  message's _fields.  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Construct a new object to represent this field, then atomically check
    # if another thread has preempted us and, if not, swap in the new object
    # we just created.  If someone has preempted us, we take that object and
    # discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Setting runs the value through the field's type checker, stores it
  in the message's _fields, and marks the message as modified; for a oneof
  member it also updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    checked_value = type_checker.CheckValue(new_value)
    self._fields[field] = checked_value
    # Check _cached_byte_size_dirty inline to improve performance, since
    # scalar setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof is None:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message, creating it on first access; assigning to the
  property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  # TODO(komarek): Can anyone explain to me why we cache the message_type this
  # way, instead of referring to field.message_type inside of getter(self)?
  # What if someone sets message_type later on (which makes for simpler
  # dynamic proto descriptor and class creation code).
  message_type = field.message_type

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field.
    created = message_type._concrete_class()  # use field.message_type?
    if field.containing_oneof is not None:
      listener = _OneofListener(self, field)
    else:
      listener = self._listener_for_children
    created._SetListener(listener)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, created)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each nested extension.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is read.
    cls: The message class that receives the constants.
  """
  for name, extension in descriptor.extensions_by_name.iteritems():
    setattr(cls, name.upper() + "_FIELD_NUMBER", extension.number)


def _AddStaticMethods(cls):
  """Adds the RegisterExtension() and FromString() static methods to cls."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    """Registers extension_handle as an extension of this message type.

    Raises:
      AssertionError: If a different extension already uses the same field
        number on this message type.
    """
    extension_handle.containing_type = cls.DESCRIPTOR
    # NOTE(review): the helpers (and hence the decoder entries) are attached
    # before the field-number conflict check below, so a rejected
    # registration has already touched cls._decoders_by_tag -- confirm that
    # this is intended.
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same
    # number already exists.
    registered = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if registered is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, registered.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    by_name = cls._extensions_by_name
    by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension.  Also register under type name.
      by_name[extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new cls instance with the serialized data s merged in."""
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""
  descriptor = item[0]
  if descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    # Repeated fields count only when they hold at least one element.
    return bool(item[1])
  if descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    """Returns the present (descriptor, value) pairs ordered by number."""
    present = [entry for entry in self._fields.iteritems()
               if _IsPresent(entry)]
    return sorted(present, key=lambda entry: entry[0].number)

  cls.ListFields = ListFields


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls."""

  # Maps the names HasField() accepts to their descriptors: every
  # non-repeated field, plus every oneof.
  singular_fields = dict(
      (field.name, field) for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED)
  # Fields inside oneofs are never repeated (enforced by the compiler).
  singular_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    """Tells whether the named singular field, or oneof, is set.

    Raises:
      ValueError: If field_name names neither a singular field nor a oneof.
    """
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set iff the member recorded for it is set; no record
      # means no member was ever set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    value = self._fields.get(field)
    return value is not None and value._is_present_in_parent

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""

  def ClearField(self, field_name):
    """Clears the named field, or the set member of the named oneof.

    Raises:
      ValueError: If field_name names neither a field nor a oneof.
    """
    try:
      target = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a field name; it may still be the name of a oneof.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message has no "%s" field.' % field_name)
      if oneof not in self._oneofs:
        # No member of the oneof is set, so there is nothing to clear.
        return
      target = self._oneofs[oneof]

    if target in self._fields:
      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst than can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[target]

    # Forget the oneof record only if it points at the field just cleared.
    owning_oneof = target.containing_oneof
    if self._oneofs.get(owning_oneof, None) is target:
      del self._oneofs[owning_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""
  def ClearExtension(self, extension_handle):
    """Removes the value stored for extension_handle, if there is one.

    Raises:
      KeyError: If _VerifyExtensionHandle() rejects extension_handle.
    """
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(): drop the stored value, then always report the
    # modification.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs Clear() on cls."""
  def Clear(self):
    """Resets the message to its empty state and marks it modified."""
    # Clear fields.
    self._fields = {}
    self._unknown_fields = ()
    # Also forget which member of each oneof was set.  Otherwise _oneofs
    # would keep pointing at fields that were just dropped from _fields.
    self._oneofs = {}
    self._Modified()
  cls.Clear = Clear


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""
  def HasExtension(self, extension_handle):
    """Tells whether the given singular extension is set.

    Raises:
      KeyError: If the handle is rejected by _VerifyExtensionHandle() or
        refers to a repeated extension.
    """
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # A sub-message only counts once it is marked present in its parent.
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__() on cls."""
  def __eq__(self, other):
    """Compares descriptors, listed fields and unknown fields."""
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Sort unknown fields because their order shouldn't affect equality test.
    unknown_fields = list(self._unknown_fields)
    unknown_fields.sort()
    other_unknown_fields = list(other._unknown_fields)
    other_unknown_fields.sort()

    return unknown_fields == other_unknown_fields

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__() on cls."""
  def __str__(self):
    text = text_format.MessageToString(self)
    return text
  cls.__str__ = __str__


def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__() on cls."""

  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')
  cls.__unicode__ = __unicode__


def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): installs _SetListener() on cls."""
  def SetListener(self, listener):
    # None stands for "no listener" and is replaced by a null listener
    # object.
    if listener is not None:
      self._listener = listener
    else:
      self._listener = message_listener_mod.NullMessageListener()
  cls._SetListener = SetListener


def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    EncodeError: If a KeyError occurs, i.e. field_type has no entry in the
      byte-size table (or the size function itself raises KeyError).
  """
  try:
    byte_size_fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    return byte_size_fn(field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls."""

  def ByteSize(self):
    """Returns the serialized size in bytes, using the cache when clean."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    known_size = sum(field_descriptor._sizer(field_value)
                     for field_descriptor, field_value in self.ListFields())
    # Unknown fields are stored as raw (tag bytes, value bytes) pairs.
    unknown_size = sum(len(tag_bytes) + len(value_bytes)
                       for tag_bytes, value_bytes in self._unknown_fields)
    total = known_size + unknown_size

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize


def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self):
    """Serializes the message, insisting that it is fully initialized.

    Raises:
      EncodeError: If IsInitialized() reports missing required fields.
    """
    # Check if the message has all of its required fields set.
    if self.IsInitialized():
      return self.SerializePartialToString()
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
        self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and _InternalSerialize() on cls.
  """

  def SerializePartialToString(self):
    """Serializes the message without checking required fields."""
    out = BytesIO()
    self._InternalSerialize(out.write)
    return out.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    """Writes listed fields, then unknown fields, through write_bytes."""
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
    # Unknown fields are kept as raw (tag bytes, value bytes) pairs and are
    # written back out unchanged.
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize


def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and _InternalParse() on cls.
  """
  def MergeFromString(self, serialized):
    """Parses serialized and merges the result into this message.

    Returns:
      len(serialized); kept for legacy reasons.

    Raises:
      DecodeError: If the data is truncated, contains an unexpected
        end-group tag, or triggers a struct.error while decoding.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error:
      # The exception is fetched through sys.exc_info() because
      # "except struct.error, e" is a SyntaxError on Python 3 and
      # "except struct.error as e" is one on Python 2.5.
      raise message_mod.DecodeError(sys.exc_info()[1])
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound to locals once so the parse loop avoids repeated lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into this message.

    Returns:
      The position at which parsing stopped: |end|, or the position of the
      tag for which SkipField() returned -1.
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # No decoder for this tag: keep the raw bytes as an unknown field.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        if not unknown_field_list:
          # _unknown_fields starts out as an empty tuple; switch to a list
          # on first use.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          # field_desc is only stored for members of a oneof.
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
851 cls.MergeFromString = MergeFromString 852 853 local_ReadTag = decoder.ReadTag 854 local_SkipField = decoder.SkipField 855 decoders_by_tag = cls._decoders_by_tag 856 857 def InternalParse(self, buffer, pos, end): 858 self._Modified() 859 field_dict = self._fields 860 unknown_field_list = self._unknown_fields 861 while pos != end: 862 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 863 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 864 if field_decoder is None: 865 value_start_pos = new_pos 866 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 867 if new_pos == -1: 868 return pos 869 if not unknown_field_list: 870 unknown_field_list = self._unknown_fields = [] 871 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) 872 pos = new_pos 873 else: 874 pos = field_decoder(buffer, new_pos, end, self, field_dict) 875 if field_desc: 876 self._UpdateOneofState(field_desc) 877 return pos 878 cls._InternalParse = InternalParse 879 880 881def _AddIsInitializedMethod(message_descriptor, cls): 882 """Adds the IsInitialized and FindInitializationError methods to the 883 protocol message class.""" 884 885 required_fields = [field for field in message_descriptor.fields 886 if field.label == _FieldDescriptor.LABEL_REQUIRED] 887 888 def IsInitialized(self, errors=None): 889 """Checks if all required fields of a message are set. 890 891 Args: 892 errors: A list which, if provided, will be populated with the field 893 paths of all missing required fields. 894 895 Returns: 896 True iff the specified message has all required fields set. 897 """ 898 899 # Performance is critical so we avoid HasField() and ListFields(). 
900 901 for field in required_fields: 902 if (field not in self._fields or 903 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 904 not self._fields[field]._is_present_in_parent)): 905 if errors is not None: 906 errors.extend(self.FindInitializationErrors()) 907 return False 908 909 for field, value in list(self._fields.items()): # dict can change size! 910 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 911 if field.label == _FieldDescriptor.LABEL_REPEATED: 912 for element in value: 913 if not element.IsInitialized(): 914 if errors is not None: 915 errors.extend(self.FindInitializationErrors()) 916 return False 917 elif value._is_present_in_parent and not value.IsInitialized(): 918 if errors is not None: 919 errors.extend(self.FindInitializationErrors()) 920 return False 921 922 return True 923 924 cls.IsInitialized = IsInitialized 925 926 def FindInitializationErrors(self): 927 """Finds required fields which are not initialized. 928 929 Returns: 930 A list of strings. Each string is a path to an uninitialized field from 931 the top-level message, e.g. "foo.bar[5].baz". 932 """ 933 934 errors = [] # simplify things 935 936 for field in required_fields: 937 if not self.HasField(field.name): 938 errors.append(field.name) 939 940 for field, value in self.ListFields(): 941 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 942 if field.is_extension: 943 name = "(%s)" % field.full_name 944 else: 945 name = field.name 946 947 if field.label == _FieldDescriptor.LABEL_REPEATED: 948 for i in xrange(len(value)): 949 element = value[i] 950 prefix = "%s[%d]." % (name, i) 951 sub_errors = element.FindInitializationErrors() 952 errors += [ prefix + error for error in sub_errors ] 953 else: 954 prefix = name + "." 
955 sub_errors = value.FindInitializationErrors() 956 errors += [ prefix + error for error in sub_errors ] 957 958 return errors 959 960 cls.FindInitializationErrors = FindInitializationErrors 961 962 963def _AddMergeFromMethod(cls): 964 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 965 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 966 967 def MergeFrom(self, msg): 968 if not isinstance(msg, cls): 969 raise TypeError( 970 "Parameter to MergeFrom() must be instance of same class: " 971 "expected %s got %s." % (cls.__name__, type(msg).__name__)) 972 973 assert msg is not self 974 self._Modified() 975 976 fields = self._fields 977 978 for field, value in msg._fields.iteritems(): 979 if field.label == LABEL_REPEATED: 980 field_value = fields.get(field) 981 if field_value is None: 982 # Construct a new object to represent this field. 983 field_value = field._default_constructor(self) 984 fields[field] = field_value 985 field_value.MergeFrom(value) 986 elif field.cpp_type == CPPTYPE_MESSAGE: 987 if value._is_present_in_parent: 988 field_value = fields.get(field) 989 if field_value is None: 990 # Construct a new object to represent this field. 991 field_value = field._default_constructor(self) 992 fields[field] = field_value 993 field_value.MergeFrom(value) 994 else: 995 self._fields[field] = value 996 997 if msg._unknown_fields: 998 if not self._unknown_fields: 999 self._unknown_fields = [] 1000 self._unknown_fields.extend(msg._unknown_fields) 1001 1002 cls.MergeFrom = MergeFrom 1003 1004 1005def _AddWhichOneofMethod(message_descriptor, cls): 1006 def WhichOneof(self, oneof_name): 1007 """Returns the name of the currently set field inside a oneof, or None.""" 1008 try: 1009 field = message_descriptor.oneofs_by_name[oneof_name] 1010 except KeyError: 1011 raise ValueError( 1012 'Protocol message has no oneof "%s" field.' 
% oneof_name) 1013 1014 nested_field = self._oneofs.get(field, None) 1015 if nested_field is not None and self.HasField(nested_field.name): 1016 return nested_field.name 1017 else: 1018 return None 1019 1020 cls.WhichOneof = WhichOneof 1021 1022 1023def _AddMessageMethods(message_descriptor, cls): 1024 """Adds implementations of all Message methods to cls.""" 1025 _AddListFieldsMethod(message_descriptor, cls) 1026 _AddHasFieldMethod(message_descriptor, cls) 1027 _AddClearFieldMethod(message_descriptor, cls) 1028 if message_descriptor.is_extendable: 1029 _AddClearExtensionMethod(cls) 1030 _AddHasExtensionMethod(cls) 1031 _AddClearMethod(message_descriptor, cls) 1032 _AddEqualsMethod(message_descriptor, cls) 1033 _AddStrMethod(message_descriptor, cls) 1034 _AddUnicodeMethod(message_descriptor, cls) 1035 _AddSetListenerMethod(cls) 1036 _AddByteSizeMethod(message_descriptor, cls) 1037 _AddSerializeToStringMethod(message_descriptor, cls) 1038 _AddSerializePartialToStringMethod(message_descriptor, cls) 1039 _AddMergeFromStringMethod(message_descriptor, cls) 1040 _AddIsInitializedMethod(message_descriptor, cls) 1041 _AddMergeFromMethod(cls) 1042 _AddWhichOneofMethod(message_descriptor, cls) 1043 1044def _AddPrivateHelperMethods(message_descriptor, cls): 1045 """Adds implementation of private helper methods to cls.""" 1046 1047 def Modified(self): 1048 """Sets the _cached_byte_size_dirty bit to true, 1049 and propagates this to our listener iff this was a state change. 1050 """ 1051 1052 # Note: Some callers check _cached_byte_size_dirty before calling 1053 # _Modified() as an extra optimization. So, if this method is ever 1054 # changed such that it does stuff even when _cached_byte_size_dirty is 1055 # already true, the callers need to be updated. 
1056 if not self._cached_byte_size_dirty: 1057 self._cached_byte_size_dirty = True 1058 self._listener_for_children.dirty = True 1059 self._is_present_in_parent = True 1060 self._listener.Modified() 1061 1062 def _UpdateOneofState(self, field): 1063 """Sets field as the active field in its containing oneof. 1064 1065 Will also delete currently active field in the oneof, if it is different 1066 from the argument. Does not mark the message as modified. 1067 """ 1068 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1069 if other_field is not field: 1070 del self._fields[other_field] 1071 self._oneofs[field.containing_oneof] = field 1072 1073 cls._Modified = Modified 1074 cls.SetInParent = Modified 1075 cls._UpdateOneofState = _UpdateOneofState 1076 1077 1078class _Listener(object): 1079 1080 """MessageListener implementation that a parent message registers with its 1081 child message. 1082 1083 In order to support semantics like: 1084 1085 foo.bar.baz.qux = 23 1086 assert foo.HasField('bar') 1087 1088 ...child objects must have back references to their parents. 1089 This helper class is at the heart of this support. 1090 """ 1091 1092 def __init__(self, parent_message): 1093 """Args: 1094 parent_message: The message whose _Modified() method we should call when 1095 we receive Modified() messages. 1096 """ 1097 # This listener establishes a back reference from a child (contained) object 1098 # to its parent (containing) object. We make this a weak reference to avoid 1099 # creating cyclic garbage when the client finishes with the 'parent' object 1100 # in the tree. 1101 if isinstance(parent_message, weakref.ProxyType): 1102 self._parent_message_weakref = parent_message 1103 else: 1104 self._parent_message_weakref = weakref.proxy(parent_message) 1105 1106 # As an optimization, we also indicate directly on the listener whether 1107 # or not the parent message is dirty. This way we can avoid traversing 1108 # up the tree in the common case. 
1109 self.dirty = False 1110 1111 def Modified(self): 1112 if self.dirty: 1113 return 1114 try: 1115 # Propagate the signal to our parents iff this is the first field set. 1116 self._parent_message_weakref._Modified() 1117 except ReferenceError: 1118 # We can get here if a client has kept a reference to a child object, 1119 # and is now setting a field on it, but the child's parent has been 1120 # garbage-collected. This is not an error. 1121 pass 1122 1123 1124class _OneofListener(_Listener): 1125 """Special listener implementation for setting composite oneof fields.""" 1126 1127 def __init__(self, parent_message, field): 1128 """Args: 1129 parent_message: The message whose _Modified() method we should call when 1130 we receive Modified() messages. 1131 field: The descriptor of the field being set in the parent message. 1132 """ 1133 super(_OneofListener, self).__init__(parent_message) 1134 self._field = field 1135 1136 def Modified(self): 1137 """Also updates the state of the containing oneof in the parent message.""" 1138 try: 1139 self._parent_message_weakref._UpdateOneofState(self._field) 1140 super(_OneofListener, self).Modified() 1141 except ReferenceError: 1142 pass 1143 1144 1145# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... 1146# TODO(robinson): Unify error handling of "unknown extension" crap. 1147# TODO(robinson): Support iteritems()-style iteration over all 1148# extensions with the "has" bits turned on? 1149class _ExtensionDict(object): 1150 1151 """Dict-like container for supporting an indexable "Extensions" 1152 field on proto instances. 1153 1154 Note that in all cases we expect extension handles to be 1155 FieldDescriptors. 1156 """ 1157 1158 def __init__(self, extended_message): 1159 """extended_message: Message instance for which we are the Extensions dict. 
1160 """ 1161 1162 self._extended_message = extended_message 1163 1164 def __getitem__(self, extension_handle): 1165 """Returns the current value of the given extension handle.""" 1166 1167 _VerifyExtensionHandle(self._extended_message, extension_handle) 1168 1169 result = self._extended_message._fields.get(extension_handle) 1170 if result is not None: 1171 return result 1172 1173 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 1174 result = extension_handle._default_constructor(self._extended_message) 1175 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1176 result = extension_handle.message_type._concrete_class() 1177 try: 1178 result._SetListener(self._extended_message._listener_for_children) 1179 except ReferenceError: 1180 pass 1181 else: 1182 # Singular scalar -- just return the default without inserting into the 1183 # dict. 1184 return extension_handle.default_value 1185 1186 # Atomically check if another thread has preempted us and, if not, swap 1187 # in the new object we just created. If someone has preempted us, we 1188 # take that object and discard ours. 1189 # WARNING: We are relying on setdefault() being atomic. This is true 1190 # in CPython but we haven't investigated others. This warning appears 1191 # in several other locations in this file. 1192 result = self._extended_message._fields.setdefault( 1193 extension_handle, result) 1194 1195 return result 1196 1197 def __eq__(self, other): 1198 if not isinstance(other, self.__class__): 1199 return False 1200 1201 my_fields = self._extended_message.ListFields() 1202 other_fields = other._extended_message.ListFields() 1203 1204 # Get rid of non-extension fields. 
1205 my_fields = [ field for field in my_fields if field.is_extension ] 1206 other_fields = [ field for field in other_fields if field.is_extension ] 1207 1208 return my_fields == other_fields 1209 1210 def __ne__(self, other): 1211 return not self == other 1212 1213 def __hash__(self): 1214 raise TypeError('unhashable object') 1215 1216 # Note that this is only meaningful for non-repeated, scalar extension 1217 # fields. Note also that we may have to call _Modified() when we do 1218 # successfully set a field this way, to set any necssary "has" bits in the 1219 # ancestors of the extended message. 1220 def __setitem__(self, extension_handle, value): 1221 """If extension_handle specifies a non-repeated, scalar extension 1222 field, sets the value of that field. 1223 """ 1224 1225 _VerifyExtensionHandle(self._extended_message, extension_handle) 1226 1227 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or 1228 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): 1229 raise TypeError( 1230 'Cannot assign to extension "%s" because it is a repeated or ' 1231 'composite type.' % extension_handle.full_name) 1232 1233 # It's slightly wasteful to lookup the type checker each time, 1234 # but we expect this to be a vanishingly uncommon case anyway. 1235 type_checker = type_checkers.GetTypeChecker( 1236 extension_handle) 1237 # pylint: disable=protected-access 1238 self._extended_message._fields[extension_handle] = ( 1239 type_checker.CheckValue(value)) 1240 self._extended_message._Modified() 1241 1242 def _FindExtensionByName(self, name): 1243 """Tries to find a known extension with the specified name. 1244 1245 Args: 1246 name: Extension full name. 1247 1248 Returns: 1249 Extension field descriptor. 1250 """ 1251 return self._extended_message._extensions_by_name.get(name, None) 1252