# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code is meant to work on Python 2.4 and above only.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

try:
  from cStringIO import StringIO
except ImportError:
  from StringIO import StringIO
import struct
import weakref

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  class MyProtoClass(Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = mydescriptor
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class.

    Raises:
      KeyError: If dictionary has no _DESCRIPTOR_KEY entry.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Both helpers mutate the class dictionary in place, so their additions
    # ('__slots__' and one attribute per nested extension) are seen by
    # type.__new__ below.
    _AddSlots(descriptor, dictionary)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    superclass = super(GeneratedProtocolMessageType, cls)
    return superclass.__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # Per-class lookup tables.  _decoders_by_tag is filled in by
    # _AttachFieldHelpers(); the extension tables are filled in by
    # RegisterExtension() (see _AddStaticMethods()).
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(cls._extensions_by_number))

    # We act as a "friend" class of the descriptor, setting
    # its _concrete_class attribute the first time we use a
    # given descriptor to initialize a concrete protocol message
    # class.  We also attach stuff to each FieldDescriptor for quick
    # lookup later on.
    concrete_class_attr_name = '_concrete_class'
    if not hasattr(descriptor, concrete_class_attr_name):
      setattr(descriptor, concrete_class_attr_name, cls)
      for field in descriptor.fields:
        _AttachFieldHelpers(cls, field)

    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(cls)
    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.
183 184 185def _PropertyName(proto_field_name): 186 """Returns the name of the public property attribute which 187 clients can use to get and (in some cases) set the value 188 of a protocol message field. 189 190 Args: 191 proto_field_name: The protocol message field name, exactly 192 as it appears (or would appear) in a .proto file. 193 """ 194 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. 195 # nnorwitz makes my day by writing: 196 # """ 197 # FYI. See the keyword module in the stdlib. This could be as simple as: 198 # 199 # if keyword.iskeyword(proto_field_name): 200 # return proto_field_name + "_" 201 # return proto_field_name 202 # """ 203 # Kenton says: The above is a BAD IDEA. People rely on being able to use 204 # getattr() and setattr() to reflectively manipulate field values. If we 205 # rename the properties, then every such user has to also make sure to apply 206 # the same transformation. Note that currently if you name a field "yield", 207 # you can still access it just fine using getattr/setattr -- it's not even 208 # that cumbersome to do so. 209 # TODO(kenton): Remove this method entirely if/when everyone agrees with my 210 # position. 211 return proto_field_name 212 213 214def _VerifyExtensionHandle(message, extension_handle): 215 """Verify that the given extension handle is valid.""" 216 217 if not isinstance(extension_handle, _FieldDescriptor): 218 raise KeyError('HasExtension() expects an extension handle, got: %s' % 219 extension_handle) 220 221 if not extension_handle.is_extension: 222 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) 223 224 if extension_handle.containing_type is not message.DESCRIPTOR: 225 raise KeyError('Extension "%s" extends message type "%s", but this ' 226 'message is of type "%s".' 
% 227 (extension_handle.full_name, 228 extension_handle.containing_type.full_name, 229 message.DESCRIPTOR.full_name)) 230 231 232def _AddSlots(message_descriptor, dictionary): 233 """Adds a __slots__ entry to dictionary, containing the names of all valid 234 attributes for this message type. 235 236 Args: 237 message_descriptor: A Descriptor instance describing this message type. 238 dictionary: Class dictionary to which we'll add a '__slots__' entry. 239 """ 240 dictionary['__slots__'] = ['_cached_byte_size', 241 '_cached_byte_size_dirty', 242 '_fields', 243 '_is_present_in_parent', 244 '_listener', 245 '_listener_for_children', 246 '__weakref__'] 247 248 249def _IsMessageSetExtension(field): 250 return (field.is_extension and 251 field.containing_type.has_options and 252 field.containing_type.GetOptions().message_set_wire_format and 253 field.type == _FieldDescriptor.TYPE_MESSAGE and 254 field.message_type == field.extension_scope and 255 field.label == _FieldDescriptor.LABEL_OPTIONAL) 256 257 258def _AttachFieldHelpers(cls, field_descriptor): 259 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) 260 is_packed = (field_descriptor.has_options and 261 field_descriptor.GetOptions().packed) 262 263 if _IsMessageSetExtension(field_descriptor): 264 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) 265 sizer = encoder.MessageSetItemSizer(field_descriptor.number) 266 else: 267 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( 268 field_descriptor.number, is_repeated, is_packed) 269 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( 270 field_descriptor.number, is_repeated, is_packed) 271 272 field_descriptor._encoder = field_encoder 273 field_descriptor._sizer = sizer 274 field_descriptor._default_constructor = _DefaultValueConstructorForField( 275 field_descriptor) 276 277 def AddDecoder(wiretype, is_packed): 278 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) 279 
cls._decoders_by_tag[tag_bytes] = ( 280 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( 281 field_descriptor.number, is_repeated, is_packed, 282 field_descriptor, field_descriptor._default_constructor)) 283 284 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], 285 False) 286 287 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): 288 # To support wire compatibility of adding packed = true, add a decoder for 289 # packed values regardless of the field's options. 290 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) 291 292 293def _AddClassAttributesForNestedExtensions(descriptor, dictionary): 294 extension_dict = descriptor.extensions_by_name 295 for extension_name, extension_field in extension_dict.iteritems(): 296 assert extension_name not in dictionary 297 dictionary[extension_name] = extension_field 298 299 300def _AddEnumValues(descriptor, cls): 301 """Sets class-level attributes for all enum fields defined in this message. 302 303 Args: 304 descriptor: Descriptor object for this message type. 305 cls: Class we're constructing for this message type. 306 """ 307 for enum_type in descriptor.enum_types: 308 for enum_value in enum_type.values: 309 setattr(cls, enum_value.name, enum_value.number) 310 311 312def _DefaultValueConstructorForField(field): 313 """Returns a function which returns a default value for a field. 314 315 Args: 316 field: FieldDescriptor object for this field. 317 318 The returned function has one argument: 319 message: Message instance containing this field, or a weakref proxy 320 of same. 321 322 That function in turn returns a default value for this field. The default 323 value may refer back to |message| via a weak reference. 
324 """ 325 326 if field.label == _FieldDescriptor.LABEL_REPEATED: 327 if field.default_value != []: 328 raise ValueError('Repeated field default value not empty list: %s' % ( 329 field.default_value)) 330 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 331 # We can't look at _concrete_class yet since it might not have 332 # been set. (Depends on order in which we initialize the classes). 333 message_type = field.message_type 334 def MakeRepeatedMessageDefault(message): 335 return containers.RepeatedCompositeFieldContainer( 336 message._listener_for_children, field.message_type) 337 return MakeRepeatedMessageDefault 338 else: 339 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 340 def MakeRepeatedScalarDefault(message): 341 return containers.RepeatedScalarFieldContainer( 342 message._listener_for_children, type_checker) 343 return MakeRepeatedScalarDefault 344 345 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 346 # _concrete_class may not yet be initialized. 
347 message_type = field.message_type 348 def MakeSubMessageDefault(message): 349 result = message_type._concrete_class() 350 result._SetListener(message._listener_for_children) 351 return result 352 return MakeSubMessageDefault 353 354 def MakeScalarDefault(message): 355 return field.default_value 356 return MakeScalarDefault 357 358 359def _AddInitMethod(message_descriptor, cls): 360 """Adds an __init__ method to cls.""" 361 fields = message_descriptor.fields 362 def init(self, **kwargs): 363 self._cached_byte_size = 0 364 self._cached_byte_size_dirty = False 365 self._fields = {} 366 self._is_present_in_parent = False 367 self._listener = message_listener_mod.NullMessageListener() 368 self._listener_for_children = _Listener(self) 369 for field_name, field_value in kwargs.iteritems(): 370 field = _GetFieldByName(message_descriptor, field_name) 371 if field is None: 372 raise TypeError("%s() got an unexpected keyword argument '%s'" % 373 (message_descriptor.name, field_name)) 374 if field.label == _FieldDescriptor.LABEL_REPEATED: 375 copy = field._default_constructor(self) 376 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite 377 for val in field_value: 378 copy.add().MergeFrom(val) 379 else: # Scalar 380 copy.extend(field_value) 381 self._fields[field] = copy 382 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 383 copy = field._default_constructor(self) 384 copy.MergeFrom(field_value) 385 self._fields[field] = copy 386 else: 387 self._fields[field] = field_value 388 389 init.__module__ = None 390 init.__doc__ = None 391 cls.__init__ = init 392 393 394def _GetFieldByName(message_descriptor, field_name): 395 """Returns a field descriptor by field name. 396 397 Args: 398 message_descriptor: A Descriptor describing all fields in message. 399 field_name: The name of the field to retrieve. 400 Returns: 401 The field descriptor associated with the field name. 
402 """ 403 try: 404 return message_descriptor.fields_by_name[field_name] 405 except KeyError: 406 raise ValueError('Protocol message has no "%s" field.' % field_name) 407 408 409def _AddPropertiesForFields(descriptor, cls): 410 """Adds properties for all fields in this protocol message type.""" 411 for field in descriptor.fields: 412 _AddPropertiesForField(field, cls) 413 414 if descriptor.is_extendable: 415 # _ExtensionDict is just an adaptor with no state so we allocate a new one 416 # every time it is accessed. 417 cls.Extensions = property(lambda self: _ExtensionDict(self)) 418 419 420def _AddPropertiesForField(field, cls): 421 """Adds a public property for a protocol message field. 422 Clients can use this property to get and (in the case 423 of non-repeated scalar fields) directly set the value 424 of a protocol message field. 425 426 Args: 427 field: A FieldDescriptor for this field. 428 cls: The class we're constructing. 429 """ 430 # Catch it if we add other types that we should 431 # handle specially here. 432 assert _FieldDescriptor.MAX_CPPTYPE == 10 433 434 constant_name = field.name.upper() + "_FIELD_NUMBER" 435 setattr(cls, constant_name, field.number) 436 437 if field.label == _FieldDescriptor.LABEL_REPEATED: 438 _AddPropertiesForRepeatedField(field, cls) 439 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 440 _AddPropertiesForNonRepeatedCompositeField(field, cls) 441 else: 442 _AddPropertiesForNonRepeatedScalarField(field, cls) 443 444 445def _AddPropertiesForRepeatedField(field, cls): 446 """Adds a public property for a "repeated" protocol message field. Clients 447 can use this property to get the value of the field, which will be either a 448 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see 449 below). 450 451 Note that when clients add values to these containers, we perform 452 type-checking in the case of repeated scalar fields, and we also set any 453 necessary "has" bits as a side-effect. 
454 455 Args: 456 field: A FieldDescriptor for this field. 457 cls: The class we're constructing. 458 """ 459 proto_field_name = field.name 460 property_name = _PropertyName(proto_field_name) 461 462 def getter(self): 463 field_value = self._fields.get(field) 464 if field_value is None: 465 # Construct a new object to represent this field. 466 field_value = field._default_constructor(self) 467 468 # Atomically check if another thread has preempted us and, if not, swap 469 # in the new object we just created. If someone has preempted us, we 470 # take that object and discard ours. 471 # WARNING: We are relying on setdefault() being atomic. This is true 472 # in CPython but we haven't investigated others. This warning appears 473 # in several other locations in this file. 474 field_value = self._fields.setdefault(field, field_value) 475 return field_value 476 getter.__module__ = None 477 getter.__doc__ = 'Getter for %s.' % proto_field_name 478 479 # We define a setter just so we can throw an exception with a more 480 # helpful error message. 481 def setter(self, new_value): 482 raise AttributeError('Assignment not allowed to repeated field ' 483 '"%s" in protocol message object.' % proto_field_name) 484 485 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 486 setattr(cls, property_name, property(getter, setter, doc=doc)) 487 488 489def _AddPropertiesForNonRepeatedScalarField(field, cls): 490 """Adds a public property for a nonrepeated, scalar protocol message field. 491 Clients can use this property to get and directly set the value of the field. 492 Note that when the client sets the value of a field by using this property, 493 all necessary "has" bits are set as a side-effect, and we also perform 494 type-checking. 495 496 Args: 497 field: A FieldDescriptor for this field. 498 cls: The class we're constructing. 
499 """ 500 proto_field_name = field.name 501 property_name = _PropertyName(proto_field_name) 502 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 503 default_value = field.default_value 504 505 def getter(self): 506 return self._fields.get(field, default_value) 507 getter.__module__ = None 508 getter.__doc__ = 'Getter for %s.' % proto_field_name 509 def setter(self, new_value): 510 type_checker.CheckValue(new_value) 511 self._fields[field] = new_value 512 # Check _cached_byte_size_dirty inline to improve performance, since scalar 513 # setters are called frequently. 514 if not self._cached_byte_size_dirty: 515 self._Modified() 516 setter.__module__ = None 517 setter.__doc__ = 'Setter for %s.' % proto_field_name 518 519 # Add a property to encapsulate the getter/setter. 520 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 521 setattr(cls, property_name, property(getter, setter, doc=doc)) 522 523 524def _AddPropertiesForNonRepeatedCompositeField(field, cls): 525 """Adds a public property for a nonrepeated, composite protocol message field. 526 A composite field is a "group" or "message" field. 527 528 Clients can use this property to get the value of the field, but cannot 529 assign to the property directly. 530 531 Args: 532 field: A FieldDescriptor for this field. 533 cls: The class we're constructing. 534 """ 535 # TODO(robinson): Remove duplication with similar method 536 # for non-repeated scalars. 537 proto_field_name = field.name 538 property_name = _PropertyName(proto_field_name) 539 message_type = field.message_type 540 541 def getter(self): 542 field_value = self._fields.get(field) 543 if field_value is None: 544 # Construct a new object to represent this field. 545 field_value = message_type._concrete_class() 546 field_value._SetListener(self._listener_for_children) 547 548 # Atomically check if another thread has preempted us and, if not, swap 549 # in the new object we just created. 
If someone has preempted us, we 550 # take that object and discard ours. 551 # WARNING: We are relying on setdefault() being atomic. This is true 552 # in CPython but we haven't investigated others. This warning appears 553 # in several other locations in this file. 554 field_value = self._fields.setdefault(field, field_value) 555 return field_value 556 getter.__module__ = None 557 getter.__doc__ = 'Getter for %s.' % proto_field_name 558 559 # We define a setter just so we can throw an exception with a more 560 # helpful error message. 561 def setter(self, new_value): 562 raise AttributeError('Assignment not allowed to composite field ' 563 '"%s" in protocol message object.' % proto_field_name) 564 565 # Add a property to encapsulate the getter. 566 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 567 setattr(cls, property_name, property(getter, setter, doc=doc)) 568 569 570def _AddPropertiesForExtensions(descriptor, cls): 571 """Adds properties for all fields in this protocol message type.""" 572 extension_dict = descriptor.extensions_by_name 573 for extension_name, extension_field in extension_dict.iteritems(): 574 constant_name = extension_name.upper() + "_FIELD_NUMBER" 575 setattr(cls, constant_name, extension_field.number) 576 577 578def _AddStaticMethods(cls): 579 # TODO(robinson): This probably needs to be thread-safe(?) 580 def RegisterExtension(extension_handle): 581 extension_handle.containing_type = cls.DESCRIPTOR 582 _AttachFieldHelpers(cls, extension_handle) 583 584 # Try to insert our extension, failing if an extension with the same number 585 # already exists. 586 actual_handle = cls._extensions_by_number.setdefault( 587 extension_handle.number, extension_handle) 588 if actual_handle is not extension_handle: 589 raise AssertionError( 590 'Extensions "%s" and "%s" both try to extend message type "%s" with ' 591 'field number %d.' 
% 592 (extension_handle.full_name, actual_handle.full_name, 593 cls.DESCRIPTOR.full_name, extension_handle.number)) 594 595 cls._extensions_by_name[extension_handle.full_name] = extension_handle 596 597 handle = extension_handle # avoid line wrapping 598 if _IsMessageSetExtension(handle): 599 # MessageSet extension. Also register under type name. 600 cls._extensions_by_name[ 601 extension_handle.message_type.full_name] = extension_handle 602 603 cls.RegisterExtension = staticmethod(RegisterExtension) 604 605 def FromString(s): 606 message = cls() 607 message.MergeFromString(s) 608 return message 609 cls.FromString = staticmethod(FromString) 610 611 612def _IsPresent(item): 613 """Given a (FieldDescriptor, value) tuple from _fields, return true if the 614 value should be included in the list returned by ListFields().""" 615 616 if item[0].label == _FieldDescriptor.LABEL_REPEATED: 617 return bool(item[1]) 618 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 619 return item[1]._is_present_in_parent 620 else: 621 return True 622 623 624def _AddListFieldsMethod(message_descriptor, cls): 625 """Helper for _AddMessageMethods().""" 626 627 def ListFields(self): 628 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] 629 all_fields.sort(key = lambda item: item[0].number) 630 return all_fields 631 632 cls.ListFields = ListFields 633 634 635def _AddHasFieldMethod(message_descriptor, cls): 636 """Helper for _AddMessageMethods().""" 637 638 singular_fields = {} 639 for field in message_descriptor.fields: 640 if field.label != _FieldDescriptor.LABEL_REPEATED: 641 singular_fields[field.name] = field 642 643 def HasField(self, field_name): 644 try: 645 field = singular_fields[field_name] 646 except KeyError: 647 raise ValueError( 648 'Protocol message has no singular "%s" field.' 
% field_name) 649 650 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 651 value = self._fields.get(field) 652 return value is not None and value._is_present_in_parent 653 else: 654 return field in self._fields 655 cls.HasField = HasField 656 657 658def _AddClearFieldMethod(message_descriptor, cls): 659 """Helper for _AddMessageMethods().""" 660 def ClearField(self, field_name): 661 try: 662 field = message_descriptor.fields_by_name[field_name] 663 except KeyError: 664 raise ValueError('Protocol message has no "%s" field.' % field_name) 665 666 if field in self._fields: 667 # Note: If the field is a sub-message, its listener will still point 668 # at us. That's fine, because the worst than can happen is that it 669 # will call _Modified() and invalidate our byte size. Big deal. 670 del self._fields[field] 671 672 # Always call _Modified() -- even if nothing was changed, this is 673 # a mutating method, and thus calling it should cause the field to become 674 # present in the parent message. 675 self._Modified() 676 677 cls.ClearField = ClearField 678 679 680def _AddClearExtensionMethod(cls): 681 """Helper for _AddMessageMethods().""" 682 def ClearExtension(self, extension_handle): 683 _VerifyExtensionHandle(self, extension_handle) 684 685 # Similar to ClearField(), above. 686 if extension_handle in self._fields: 687 del self._fields[extension_handle] 688 self._Modified() 689 cls.ClearExtension = ClearExtension 690 691 692def _AddClearMethod(message_descriptor, cls): 693 """Helper for _AddMessageMethods().""" 694 def Clear(self): 695 # Clear fields. 696 self._fields = {} 697 self._Modified() 698 cls.Clear = Clear 699 700 701def _AddHasExtensionMethod(cls): 702 """Helper for _AddMessageMethods().""" 703 def HasExtension(self, extension_handle): 704 _VerifyExtensionHandle(self, extension_handle) 705 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 706 raise KeyError('"%s" is repeated.' 
% extension_handle.full_name) 707 708 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 709 value = self._fields.get(extension_handle) 710 return value is not None and value._is_present_in_parent 711 else: 712 return extension_handle in self._fields 713 cls.HasExtension = HasExtension 714 715 716def _AddEqualsMethod(message_descriptor, cls): 717 """Helper for _AddMessageMethods().""" 718 def __eq__(self, other): 719 if (not isinstance(other, message_mod.Message) or 720 other.DESCRIPTOR != self.DESCRIPTOR): 721 return False 722 723 if self is other: 724 return True 725 726 return self.ListFields() == other.ListFields() 727 728 cls.__eq__ = __eq__ 729 730 731def _AddStrMethod(message_descriptor, cls): 732 """Helper for _AddMessageMethods().""" 733 def __str__(self): 734 return text_format.MessageToString(self) 735 cls.__str__ = __str__ 736 737 738def _AddSetListenerMethod(cls): 739 """Helper for _AddMessageMethods().""" 740 def SetListener(self, listener): 741 if listener is None: 742 self._listener = message_listener_mod.NullMessageListener() 743 else: 744 self._listener = listener 745 cls._SetListener = SetListener 746 747 748def _BytesForNonRepeatedElement(value, field_number, field_type): 749 """Returns the number of bytes needed to serialize a non-repeated element. 750 The returned byte count includes space for tag information and any 751 other additional space associated with serializing value. 752 753 Args: 754 value: Value we're serializing. 755 field_number: Field number of this value. (Since the field number 756 is stored as part of a varint-encoded tag, this has an impact 757 on the total bytes required to serialize the value). 758 field_type: The type of the field. One of the TYPE_* constants 759 within FieldDescriptor. 
760 """ 761 try: 762 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 763 return fn(field_number, value) 764 except KeyError: 765 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 766 767 768def _AddByteSizeMethod(message_descriptor, cls): 769 """Helper for _AddMessageMethods().""" 770 771 def ByteSize(self): 772 if not self._cached_byte_size_dirty: 773 return self._cached_byte_size 774 775 size = 0 776 for field_descriptor, field_value in self.ListFields(): 777 size += field_descriptor._sizer(field_value) 778 779 self._cached_byte_size = size 780 self._cached_byte_size_dirty = False 781 self._listener_for_children.dirty = False 782 return size 783 784 cls.ByteSize = ByteSize 785 786 787def _AddSerializeToStringMethod(message_descriptor, cls): 788 """Helper for _AddMessageMethods().""" 789 790 def SerializeToString(self): 791 # Check if the message has all of its required fields set. 792 errors = [] 793 if not self.IsInitialized(): 794 raise message_mod.EncodeError( 795 'Message is missing required fields: ' + 796 ','.join(self.FindInitializationErrors())) 797 return self.SerializePartialToString() 798 cls.SerializeToString = SerializeToString 799 800 801def _AddSerializePartialToStringMethod(message_descriptor, cls): 802 """Helper for _AddMessageMethods().""" 803 804 def SerializePartialToString(self): 805 out = StringIO() 806 self._InternalSerialize(out.write) 807 return out.getvalue() 808 cls.SerializePartialToString = SerializePartialToString 809 810 def InternalSerialize(self, write_bytes): 811 for field_descriptor, field_value in self.ListFields(): 812 field_descriptor._encoder(write_bytes, field_value) 813 cls._InternalSerialize = InternalSerialize 814 815 816def _AddMergeFromStringMethod(message_descriptor, cls): 817 """Helper for _AddMessageMethods().""" 818 def MergeFromString(self, serialized): 819 length = len(serialized) 820 try: 821 if self._InternalParse(serialized, 0, length) != length: 822 # The only reason _InternalParse 
would return early is if it 823 # encountered an end-group tag. 824 raise message_mod.DecodeError('Unexpected end-group tag.') 825 except IndexError: 826 raise message_mod.DecodeError('Truncated message.') 827 except struct.error, e: 828 raise message_mod.DecodeError(e) 829 return length # Return this for legacy reasons. 830 cls.MergeFromString = MergeFromString 831 832 local_ReadTag = decoder.ReadTag 833 local_SkipField = decoder.SkipField 834 decoders_by_tag = cls._decoders_by_tag 835 836 def InternalParse(self, buffer, pos, end): 837 self._Modified() 838 field_dict = self._fields 839 while pos != end: 840 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 841 field_decoder = decoders_by_tag.get(tag_bytes) 842 if field_decoder is None: 843 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 844 if new_pos == -1: 845 return pos 846 pos = new_pos 847 else: 848 pos = field_decoder(buffer, new_pos, end, self, field_dict) 849 return pos 850 cls._InternalParse = InternalParse 851 852 853def _AddIsInitializedMethod(message_descriptor, cls): 854 """Adds the IsInitialized and FindInitializationError methods to the 855 protocol message class.""" 856 857 required_fields = [field for field in message_descriptor.fields 858 if field.label == _FieldDescriptor.LABEL_REQUIRED] 859 860 def IsInitialized(self, errors=None): 861 """Checks if all required fields of a message are set. 862 863 Args: 864 errors: A list which, if provided, will be populated with the field 865 paths of all missing required fields. 866 867 Returns: 868 True iff the specified message has all required fields set. 869 """ 870 871 # Performance is critical so we avoid HasField() and ListFields(). 
872 873 for field in required_fields: 874 if (field not in self._fields or 875 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 876 not self._fields[field]._is_present_in_parent)): 877 if errors is not None: 878 errors.extend(self.FindInitializationErrors()) 879 return False 880 881 for field, value in self._fields.iteritems(): 882 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 883 if field.label == _FieldDescriptor.LABEL_REPEATED: 884 for element in value: 885 if not element.IsInitialized(): 886 if errors is not None: 887 errors.extend(self.FindInitializationErrors()) 888 return False 889 elif value._is_present_in_parent and not value.IsInitialized(): 890 if errors is not None: 891 errors.extend(self.FindInitializationErrors()) 892 return False 893 894 return True 895 896 cls.IsInitialized = IsInitialized 897 898 def FindInitializationErrors(self): 899 """Finds required fields which are not initialized. 900 901 Returns: 902 A list of strings. Each string is a path to an uninitialized field from 903 the top-level message, e.g. "foo.bar[5].baz". 904 """ 905 906 errors = [] # simplify things 907 908 for field in required_fields: 909 if not self.HasField(field.name): 910 errors.append(field.name) 911 912 for field, value in self.ListFields(): 913 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 914 if field.is_extension: 915 name = "(%s)" % field.full_name 916 else: 917 name = field.name 918 919 if field.label == _FieldDescriptor.LABEL_REPEATED: 920 for i in xrange(len(value)): 921 element = value[i] 922 prefix = "%s[%d]." % (name, i) 923 sub_errors = element.FindInitializationErrors() 924 errors += [ prefix + error for error in sub_errors ] 925 else: 926 prefix = name + "." 
927 sub_errors = value.FindInitializationErrors() 928 errors += [ prefix + error for error in sub_errors ] 929 930 return errors 931 932 cls.FindInitializationErrors = FindInitializationErrors 933 934 935def _AddMergeFromMethod(cls): 936 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 937 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 938 939 def MergeFrom(self, msg): 940 assert msg is not self 941 self._Modified() 942 943 fields = self._fields 944 945 for field, value in msg._fields.iteritems(): 946 if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE: 947 field_value = fields.get(field) 948 if field_value is None: 949 # Construct a new object to represent this field. 950 field_value = field._default_constructor(self) 951 fields[field] = field_value 952 field_value.MergeFrom(value) 953 else: 954 self._fields[field] = value 955 cls.MergeFrom = MergeFrom 956 957 958def _AddMessageMethods(message_descriptor, cls): 959 """Adds implementations of all Message methods to cls.""" 960 _AddListFieldsMethod(message_descriptor, cls) 961 _AddHasFieldMethod(message_descriptor, cls) 962 _AddClearFieldMethod(message_descriptor, cls) 963 if message_descriptor.is_extendable: 964 _AddClearExtensionMethod(cls) 965 _AddHasExtensionMethod(cls) 966 _AddClearMethod(message_descriptor, cls) 967 _AddEqualsMethod(message_descriptor, cls) 968 _AddStrMethod(message_descriptor, cls) 969 _AddSetListenerMethod(cls) 970 _AddByteSizeMethod(message_descriptor, cls) 971 _AddSerializeToStringMethod(message_descriptor, cls) 972 _AddSerializePartialToStringMethod(message_descriptor, cls) 973 _AddMergeFromStringMethod(message_descriptor, cls) 974 _AddIsInitializedMethod(message_descriptor, cls) 975 _AddMergeFromMethod(cls) 976 977 978def _AddPrivateHelperMethods(cls): 979 """Adds implementation of private helper methods to cls.""" 980 981 def Modified(self): 982 """Sets the _cached_byte_size_dirty bit to true, 983 and propagates this to our listener iff this was a state change. 
984 """ 985 986 # Note: Some callers check _cached_byte_size_dirty before calling 987 # _Modified() as an extra optimization. So, if this method is ever 988 # changed such that it does stuff even when _cached_byte_size_dirty is 989 # already true, the callers need to be updated. 990 if not self._cached_byte_size_dirty: 991 self._cached_byte_size_dirty = True 992 self._listener_for_children.dirty = True 993 self._is_present_in_parent = True 994 self._listener.Modified() 995 996 cls._Modified = Modified 997 cls.SetInParent = Modified 998 999 1000class _Listener(object): 1001 1002 """MessageListener implementation that a parent message registers with its 1003 child message. 1004 1005 In order to support semantics like: 1006 1007 foo.bar.baz.qux = 23 1008 assert foo.HasField('bar') 1009 1010 ...child objects must have back references to their parents. 1011 This helper class is at the heart of this support. 1012 """ 1013 1014 def __init__(self, parent_message): 1015 """Args: 1016 parent_message: The message whose _Modified() method we should call when 1017 we receive Modified() messages. 1018 """ 1019 # This listener establishes a back reference from a child (contained) object 1020 # to its parent (containing) object. We make this a weak reference to avoid 1021 # creating cyclic garbage when the client finishes with the 'parent' object 1022 # in the tree. 1023 if isinstance(parent_message, weakref.ProxyType): 1024 self._parent_message_weakref = parent_message 1025 else: 1026 self._parent_message_weakref = weakref.proxy(parent_message) 1027 1028 # As an optimization, we also indicate directly on the listener whether 1029 # or not the parent message is dirty. This way we can avoid traversing 1030 # up the tree in the common case. 1031 self.dirty = False 1032 1033 def Modified(self): 1034 if self.dirty: 1035 return 1036 try: 1037 # Propagate the signal to our parents iff this is the first field set. 
1038 self._parent_message_weakref._Modified() 1039 except ReferenceError: 1040 # We can get here if a client has kept a reference to a child object, 1041 # and is now setting a field on it, but the child's parent has been 1042 # garbage-collected. This is not an error. 1043 pass 1044 1045 1046# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... 1047# TODO(robinson): Unify error handling of "unknown extension" crap. 1048# TODO(robinson): Support iteritems()-style iteration over all 1049# extensions with the "has" bits turned on? 1050class _ExtensionDict(object): 1051 1052 """Dict-like container for supporting an indexable "Extensions" 1053 field on proto instances. 1054 1055 Note that in all cases we expect extension handles to be 1056 FieldDescriptors. 1057 """ 1058 1059 def __init__(self, extended_message): 1060 """extended_message: Message instance for which we are the Extensions dict. 1061 """ 1062 1063 self._extended_message = extended_message 1064 1065 def __getitem__(self, extension_handle): 1066 """Returns the current value of the given extension handle.""" 1067 1068 _VerifyExtensionHandle(self._extended_message, extension_handle) 1069 1070 result = self._extended_message._fields.get(extension_handle) 1071 if result is not None: 1072 return result 1073 1074 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 1075 result = extension_handle._default_constructor(self._extended_message) 1076 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1077 result = extension_handle.message_type._concrete_class() 1078 try: 1079 result._SetListener(self._extended_message._listener_for_children) 1080 except ReferenceError: 1081 pass 1082 else: 1083 # Singular scalar -- just return the default without inserting into the 1084 # dict. 1085 return extension_handle.default_value 1086 1087 # Atomically check if another thread has preempted us and, if not, swap 1088 # in the new object we just created. 
If someone has preempted us, we 1089 # take that object and discard ours. 1090 # WARNING: We are relying on setdefault() being atomic. This is true 1091 # in CPython but we haven't investigated others. This warning appears 1092 # in several other locations in this file. 1093 result = self._extended_message._fields.setdefault( 1094 extension_handle, result) 1095 1096 return result 1097 1098 def __eq__(self, other): 1099 if not isinstance(other, self.__class__): 1100 return False 1101 1102 my_fields = self._extended_message.ListFields() 1103 other_fields = other._extended_message.ListFields() 1104 1105 # Get rid of non-extension fields. 1106 my_fields = [ field for field in my_fields if field.is_extension ] 1107 other_fields = [ field for field in other_fields if field.is_extension ] 1108 1109 return my_fields == other_fields 1110 1111 def __ne__(self, other): 1112 return not self == other 1113 1114 # Note that this is only meaningful for non-repeated, scalar extension 1115 # fields. Note also that we may have to call _Modified() when we do 1116 # successfully set a field this way, to set any necssary "has" bits in the 1117 # ancestors of the extended message. 1118 def __setitem__(self, extension_handle, value): 1119 """If extension_handle specifies a non-repeated, scalar extension 1120 field, sets the value of that field. 1121 """ 1122 1123 _VerifyExtensionHandle(self._extended_message, extension_handle) 1124 1125 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or 1126 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): 1127 raise TypeError( 1128 'Cannot assign to extension "%s" because it is a repeated or ' 1129 'composite type.' % extension_handle.full_name) 1130 1131 # It's slightly wasteful to lookup the type checker each time, 1132 # but we expect this to be a vanishingly uncommon case anyway. 
1133 type_checker = type_checkers.GetTypeChecker( 1134 extension_handle.cpp_type, extension_handle.type) 1135 type_checker.CheckValue(value) 1136 self._extended_message._fields[extension_handle] = value 1137 self._extended_message._Modified() 1138 1139 def _FindExtensionByName(self, name): 1140 """Tries to find a known extension with the specified name. 1141 1142 Args: 1143 name: Extension full name. 1144 1145 Returns: 1146 Extension field descriptor. 1147 """ 1148 return self._extended_message._extensions_by_name.get(name, None) 1149