1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# http://code.google.com/p/protobuf/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53try:
54  from cStringIO import StringIO
55except ImportError:
56  from StringIO import StringIO
57import struct
58import weakref
59
60# We use "as" to avoid name collisions with variables.
61from google.protobuf.internal import containers
62from google.protobuf.internal import decoder
63from google.protobuf.internal import encoder
64from google.protobuf.internal import message_listener as message_listener_mod
65from google.protobuf.internal import type_checkers
66from google.protobuf.internal import wire_format
67from google.protobuf import descriptor as descriptor_mod
68from google.protobuf import message as message_mod
69from google.protobuf import text_format
70
71_FieldDescriptor = descriptor_mod.FieldDescriptor
72
73
class GeneratedProtocolMessageType(type):

  """Metaclass that turns a Descriptor into a working protocol message class.

  Every method described by the Message class gets an implementation here,
  every field in the protocol message gets a get/set property, and __slots__
  is installed so that users cannot accidentally "set" a field that does not
  exist (such a value would never be serialized or deserialized).

  The protocol compiler relies on this metaclass to build message classes at
  runtime.  Clients may also build their own classes by hand, for example:

  mydescriptor = Descriptor(.....)
  class MyProtoClass(Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = mydescriptor
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the new message class after preparing its class dictionary.

    __new__ is overridden because it appears to be the only point at which
    __slots__ can meaningfully be placed on the class being created.  (The
    interplay between metaclasses and slots is not very well-documented.)

    Args:
      name: Name of the class (unused here, but required by the
        metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).  Unused here, but required by the metaclass
        protocol.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must hold the Descriptor object that
        describes this protocol message type.

    Returns:
      Newly-allocated class.
    """
    msg_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Both helpers mutate `dictionary` in place before the class is created.
    _AddSlots(msg_descriptor, dictionary)
    _AddClassAttributesForNestedExtensions(msg_descriptor, dictionary)
    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly created class; most of the work happens here.

    Installs enum constants, an __init__ method, implementations of all
    Message methods, and properties for all fields of the protocol type.

    Args:
      name: Name of the class (unused here, but required by the
        metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).  Unused here, but required by the metaclass
        protocol.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must hold the Descriptor object that
        describes this protocol message type.
    """
    msg_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # Per-class lookup tables used when parsing and when handling extensions.
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    if msg_descriptor.has_options:
      if msg_descriptor.GetOptions().message_set_wire_format:
        cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
            decoder.MessageSetItemDecoder(cls._extensions_by_number))

    # We behave as a "friend" of the descriptor: the first class initialized
    # from a given descriptor is recorded as its _concrete_class, and helper
    # objects are attached to each FieldDescriptor for fast lookup later on.
    if not hasattr(msg_descriptor, '_concrete_class'):
      msg_descriptor._concrete_class = cls
      for field in msg_descriptor.fields:
        _AttachFieldHelpers(cls, field)

    _AddEnumValues(msg_descriptor, cls)
    _AddInitMethod(msg_descriptor, cls)
    _AddPropertiesForFields(msg_descriptor, cls)
    _AddPropertiesForExtensions(msg_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(msg_descriptor, cls)
    _AddPrivateHelperMethods(cls)
    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
175
176
177# Stateless helpers for GeneratedProtocolMessageType below.
178# Outside clients should not access these directly.
179#
180# I opted not to make any of these methods on the metaclass, to make it more
181# clear that I'm not really using any state there and to keep clients from
182# thinking that they have direct access to these construction helpers.
183
184
185def _PropertyName(proto_field_name):
186  """Returns the name of the public property attribute which
187  clients can use to get and (in some cases) set the value
188  of a protocol message field.
189
190  Args:
191    proto_field_name: The protocol message field name, exactly
192      as it appears (or would appear) in a .proto file.
193  """
194  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
195  # nnorwitz makes my day by writing:
196  # """
197  # FYI.  See the keyword module in the stdlib. This could be as simple as:
198  #
199  # if keyword.iskeyword(proto_field_name):
200  #   return proto_field_name + "_"
201  # return proto_field_name
202  # """
203  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
204  #   getattr() and setattr() to reflectively manipulate field values.  If we
205  #   rename the properties, then every such user has to also make sure to apply
206  #   the same transformation.  Note that currently if you name a field "yield",
207  #   you can still access it just fine using getattr/setattr -- it's not even
208  #   that cumbersome to do so.
209  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
210  #   position.
211  return proto_field_name
212
213
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that extension_handle is an extension of message's type.

  Args:
    message: The message whose DESCRIPTOR the extension must extend.
    extension_handle: The candidate extension FieldDescriptor.

  Raises:
    KeyError: If the handle is not a FieldDescriptor, is not an extension,
      or extends a different message type than that of message.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  # Identity (not equality) comparison: the handle must point at the very
  # descriptor object used by this message.
  extended_type = extension_handle.containing_type
  if extended_type is message.DESCRIPTOR:
    return

  raise KeyError('Extension "%s" extends message type "%s", but this '
                 'message is of type "%s".' %
                 (extension_handle.full_name,
                  extension_handle.containing_type.full_name,
                  message.DESCRIPTOR.full_name))
230
231
232def _AddSlots(message_descriptor, dictionary):
233  """Adds a __slots__ entry to dictionary, containing the names of all valid
234  attributes for this message type.
235
236  Args:
237    message_descriptor: A Descriptor instance describing this message type.
238    dictionary: Class dictionary to which we'll add a '__slots__' entry.
239  """
240  dictionary['__slots__'] = ['_cached_byte_size',
241                             '_cached_byte_size_dirty',
242                             '_fields',
243                             '_is_present_in_parent',
244                             '_listener',
245                             '_listener_for_children',
246                             '__weakref__']
247
248
def _IsMessageSetExtension(field):
  """Tells whether field is an extension stored in MessageSet item format.

  That is the case when the field is an optional message-typed extension,
  the extended type has the message_set_wire_format option set, and the
  extension is scoped inside its own message type.

  Args:
    field: The FieldDescriptor to examine.
  """
  if not field.is_extension:
    return False

  extended_type = field.containing_type
  if not extended_type.has_options:
    return False
  if not extended_type.GetOptions().message_set_wire_format:
    return False

  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
256
257
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encode/size/default helpers to a field and registers decoders.

  The encoder, sizer and default-value constructor are stored on the
  FieldDescriptor itself (_encoder, _sizer, _default_constructor); decoders
  are stored in cls._decoders_by_tag, keyed by the encoded tag bytes.

  Args:
    cls: The message class whose _decoders_by_tag table is updated.
    field_descriptor: FieldDescriptor of the field (or extension) to set up.
  """
  number = field_descriptor.number
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    helpers = (encoder.MessageSetItemEncoder(number),
               encoder.MessageSetItemSizer(number))
  else:
    helpers = (
        type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
            number, is_repeated, is_packed),
        type_checkers.TYPE_TO_SIZER[field_descriptor.type](
            number, is_repeated, is_packed))

  field_descriptor._encoder, field_descriptor._sizer = helpers
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def RegisterDecoder(wire_type, decode_packed):
    # One decoder per (field number, wire type) tag.
    tag = encoder.TagBytes(field_descriptor.number, wire_type)
    make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
    cls._decoders_by_tag[tag] = make_decoder(
        field_descriptor.number, is_repeated, decode_packed,
        field_descriptor, field_descriptor._default_constructor)

  RegisterDecoder(
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
291
292
293def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
294  extension_dict = descriptor.extensions_by_name
295  for extension_name, extension_field in extension_dict.iteritems():
296    assert extension_name not in dictionary
297    dictionary[extension_name] = extension_field
298
299
300def _AddEnumValues(descriptor, cls):
301  """Sets class-level attributes for all enum fields defined in this message.
302
303  Args:
304    descriptor: Descriptor object for this message type.
305    cls: Class we're constructing for this message type.
306  """
307  for enum_type in descriptor.enum_types:
308    for enum_value in enum_type.values:
309      setattr(cls, enum_value.name, enum_value.number)
310
311
def _DefaultValueConstructorForField(field):
  """Builds the factory that produces default values for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    Calling it yields a default value for the field.  Container and
    sub-message defaults are wired to message._listener_for_children.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """
  is_repeated = (field.label == _FieldDescriptor.LABEL_REPEATED)
  if is_repeated and field.default_value != []:
    raise ValueError('Repeated field default value not empty list: %s' % (
        field.default_value))

  is_composite = (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE)

  if is_repeated and is_composite:
    # _concrete_class cannot be looked up yet since it might not have been
    # set (that depends on the order in which classes are initialized), so
    # the container is handed the message type descriptor instead.
    def MakeRepeatedMessageDefault(message):
      return containers.RepeatedCompositeFieldContainer(
          message._listener_for_children, field.message_type)
    return MakeRepeatedMessageDefault

  if is_repeated:
    type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
    def MakeRepeatedScalarDefault(message):
      return containers.RepeatedScalarFieldContainer(
          message._listener_for_children, type_checker)
    return MakeRepeatedScalarDefault

  if is_composite:
    # _concrete_class may not yet be initialized, so it is resolved only
    # when a default is actually requested.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      submessage = message_type._concrete_class()
      submessage._SetListener(message._listener_for_children)
      return submessage
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    return field.default_value
  return MakeScalarDefault
357
358
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets the per-instance bookkeeping slots and then
  populates fields from keyword arguments named after the message's fields.
  Repeated and sub-message values are copied into freshly constructed
  containers / messages; scalar values are stored as given, without
  type-checking.

  Unknown keyword names are rejected by _GetFieldByName(), which raises
  ValueError.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The message class that receives the __init__ method.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Field values are written straight into _fields below; scalar values in
    # particular never go through _Modified().  The cached size must therefore
    # start out dirty whenever any field is supplied, otherwise ByteSize()
    # would keep returning the cached 0 for a non-empty message.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.iteritems():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive guard only: _GetFieldByName() currently raises ValueError
      # for unknown names instead of returning None.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          for val in field_value:
            copy.add().MergeFrom(val)
        else:  # Scalar
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        copy.MergeFrom(field_value)
        self._fields[field] = copy
      else:
        self._fields[field] = field_value

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
392
393
394def _GetFieldByName(message_descriptor, field_name):
395  """Returns a field descriptor by field name.
396
397  Args:
398    message_descriptor: A Descriptor describing all fields in message.
399    field_name: The name of the field to retrieve.
400  Returns:
401    The field descriptor associated with the field name.
402  """
403  try:
404    return message_descriptor.fields_by_name[field_name]
405  except KeyError:
406    raise ValueError('Protocol message has no "%s" field.' % field_name)
407
408
def _AddPropertiesForFields(descriptor, cls):
  """Installs a property on cls for every field of the message type.

  Extendable message types additionally get an `Extensions` property.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  def GetExtensions(self):
    return _ExtensionDict(self)
  cls.Extensions = property(GetExtensions)
418
419
def _AddPropertiesForField(field, cls):
  """Installs the public property and number constant for one field.

  Clients use the property to read the field and, for non-repeated scalar
  fields, to assign to it directly.  A <FIELD_NAME>_FIELD_NUMBER class
  constant holding the field number is added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  # Choose the installer matching the kind of field, then run it.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
443
444
def _AddPropertiesForRepeatedField(field, cls):
  """Installs the public property for a "repeated" protocol message field.

  Reading the property returns the field's container, which is created on
  first access by field._default_constructor.  Assigning to the property is
  rejected with an AttributeError; clients mutate the container instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Nothing stored yet: construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % field_name)

  field_property = property(
      getter, setter,
      doc='Magic attribute generated for "%s" proto field.' % field_name)
  setattr(cls, _PropertyName(field_name), field_property)
487
488
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs the public property for a non-repeated scalar field.

  Reading the property yields the stored value, or the field's default value
  when nothing is stored.  Assigning to it type-checks the new value, stores
  it, and notifies the message that it was modified.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  value_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  fallback = field.default_value

  def getter(self):
    return self._fields.get(field, fallback)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  def setter(self, new_value):
    value_checker.CheckValue(new_value)
    self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % field_name

  # Wrap the getter/setter pair in a property and publish it on the class.
  field_property = property(
      getter, setter,
      doc='Magic attribute generated for "%s" proto field.' % field_name)
  setattr(cls, _PropertyName(field_name), field_property)
522
523
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs the public property for a non-repeated composite field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message, creating it on first access; assigning to the
  property is rejected with an AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  field_name = field.name
  submessage_type = field.message_type

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Nothing stored yet: construct a new object to represent this field.
    fresh = submessage_type._concrete_class()
    fresh._SetListener(self._listener_for_children)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % field_name)

  # Wrap the getter (and the rejecting setter) in a property.
  field_property = property(
      getter, setter,
      doc='Magic attribute generated for "%s" proto field.' % field_name)
  setattr(cls, _PropertyName(field_name), field_property)
568
569
570def _AddPropertiesForExtensions(descriptor, cls):
571  """Adds properties for all fields in this protocol message type."""
572  extension_dict = descriptor.extensions_by_name
573  for extension_name, extension_field in extension_dict.iteritems():
574    constant_name = extension_name.upper() + "_FIELD_NUMBER"
575    setattr(cls, constant_name, extension_field.number)
576
577
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension() and FromString() static methods to cls.

  Args:
    cls: The message class being constructed.
  """
  def NumberConflictError(extension_handle, actual_handle):
    # Builds the error reported when two extensions claim one field number.
    return AssertionError(
        'Extensions "%s" and "%s" both try to extend message type "%s" with '
        'field number %d.' %
        (extension_handle.full_name, actual_handle.full_name,
         cls.DESCRIPTOR.full_name, extension_handle.number))

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    """Registers extension_handle as an extension of this message class.

    Raises:
      AssertionError: If a different extension is already registered under
        the same field number.
    """
    # Fail fast on a field-number conflict *before* touching any state.
    # _AttachFieldHelpers() overwrites entries in cls._decoders_by_tag and the
    # assignment below rewrites the handle's containing_type, so a rejected
    # registration must not get that far.
    registered = cls._extensions_by_number.get(extension_handle.number)
    if registered is not None and registered is not extension_handle:
      raise NumberConflictError(extension_handle, registered)

    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same number
    # already exists.  (Kept in addition to the check above because
    # setdefault() performs the check-and-insert in a single step.)
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise NumberConflictError(extension_handle, actual_handle)

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    handle = extension_handle  # avoid line wrapping
    if _IsMessageSetExtension(handle):
      # MessageSet extension.  Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new instance of this class parsed from the serialized s."""
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)
610
611
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from a message's _fields.

  Returns:
    For repeated fields, whether the container is non-empty; for sub-message
    fields, the sub-message's _is_present_in_parent flag; True for any other
    stored value.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
622
623
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs ListFields() on cls."""

  def FieldNumber(item):
    # item is a (FieldDescriptor, value) pair.
    return item[0].number

  def ListFields(self):
    """Returns the present (FieldDescriptor, value) pairs by field number."""
    present = [pair for pair in self._fields.iteritems() if _IsPresent(pair)]
    return sorted(present, key=FieldNumber)

  cls.ListFields = ListFields
633
634
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs HasField() on cls."""

  # Only non-repeated fields can be queried through HasField().
  singular_fields = dict(
      [(candidate.name, candidate) for candidate in message_descriptor.fields
       if candidate.label != _FieldDescriptor.LABEL_REPEATED])

  def HasField(self, field_name):
    """Tells whether the named singular field is set.

    Raises:
      ValueError: If field_name is not a singular field of this message.
    """
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields

    # A sub-message may be stored without counting as present yet.
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent
  cls.HasField = HasField
656
657
658def _AddClearFieldMethod(message_descriptor, cls):
659  """Helper for _AddMessageMethods()."""
660  def ClearField(self, field_name):
661    try:
662      field = message_descriptor.fields_by_name[field_name]
663    except KeyError:
664      raise ValueError('Protocol message has no "%s" field.' % field_name)
665
666    if field in self._fields:
667      # Note:  If the field is a sub-message, its listener will still point
668      #   at us.  That's fine, because the worst than can happen is that it
669      #   will call _Modified() and invalidate our byte size.  Big deal.
670      del self._fields[field]
671
672    # Always call _Modified() -- even if nothing was changed, this is
673    # a mutating method, and thus calling it should cause the field to become
674    # present in the parent message.
675    self._Modified()
676
677  cls.ClearField = ClearField
678
679
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(); installs ClearExtension() on cls."""
  def ClearExtension(self, extension_handle):
    """Drops any stored value for the extension; see ClearField()."""
    _VerifyExtensionHandle(self, extension_handle)

    # Mirrors ClearField(): discard the value if there is one, and always
    # report the modification.
    self._fields.pop(extension_handle, None)
    self._Modified()
  cls.ClearExtension = ClearExtension
690
691
692def _AddClearMethod(message_descriptor, cls):
693  """Helper for _AddMessageMethods()."""
694  def Clear(self):
695    # Clear fields.
696    self._fields = {}
697    self._Modified()
698  cls.Clear = Clear
699
700
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(); installs HasExtension() on cls."""
  def HasExtension(self, extension_handle):
    """Tells whether the given singular extension is set.

    Raises:
      KeyError: If the handle is not a valid extension of this message, or
        if the extension is repeated.
    """
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields

    # A sub-message may be stored without counting as present yet.
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent
  cls.HasExtension = HasExtension
714
715
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs __eq__ on cls."""
  def __eq__(self, other):
    """Compares two messages by descriptor and by their listed fields."""
    # Anything that is not a Message never compares equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    # Same object: no need to compare field by field.
    if self is other:
      return True

    return self.ListFields() == other.ListFields()

  cls.__eq__ = __eq__
729
730
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs __str__ on cls."""
  def __str__(self):
    # Delegates the rendering to the text_format module.
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__str__ = __str__
736
737
738def _AddSetListenerMethod(cls):
739  """Helper for _AddMessageMethods()."""
740  def SetListener(self, listener):
741    if listener is None:
742      self._listener = message_listener_mod.NullMessageListener()
743    else:
744      self._listener = listener
745  cls._SetListener = SetListener
746
747
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If field_type has no entry in
      type_checkers.TYPE_TO_BYTE_SIZE_FN.
  """
  # Only the table lookup is guarded: a KeyError raised from inside the
  # sizing function itself is a different problem and must not be reported
  # as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
766
767
def _AddByteSizeMethod(message_descriptor, cls):
  """Installs the ByteSize() method on cls.

  Helper for _AddMessageMethods().

  Args:
    message_descriptor: Not used by this helper.
    cls: The message class that receives the method.
  """

  def ByteSize(self):
    """Returns the summed size of every listed field, using a cached value
    when the cache is not marked dirty.

    A recomputation stores the new total and clears the dirty flag both on
    this message and on the listener it hands to its children.
    """
    if self._cached_byte_size_dirty:
      total = sum(descriptor._sizer(value)
                  for descriptor, value in self.ListFields())
      self._cached_byte_size = total
      self._cached_byte_size_dirty = False
      self._listener_for_children.dirty = False
      return total

    return self._cached_byte_size

  cls.ByteSize = ByteSize
785
786
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Installs the SerializeToString() method on cls.

  Helper for _AddMessageMethods().

  Args:
    message_descriptor: Not used by this helper.
    cls: The message class that receives the method.
  """

  def SerializeToString(self):
    """Serializes the message after checking that it is fully initialized.

    Returns:
      The result of self.SerializePartialToString().

    Raises:
      message_mod.EncodeError: If self.IsInitialized() is false.  The error
        text lists the paths from self.FindInitializationErrors(), joined
        with commas.
    """
    # Check if the message has all of its required fields set.
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message is missing required fields: ' +
          ','.join(self.FindInitializationErrors()))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
799
800
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Installs SerializePartialToString() and _InternalSerialize() on cls.

  Helper for _AddMessageMethods().

  Args:
    message_descriptor: Not used by this helper.
    cls: The message class that receives the methods.
  """

  def InternalSerialize(self, write_bytes):
    """Passes write_bytes and each listed field's value to that field's
    encoder, in ListFields() order."""
    for descriptor, value in self.ListFields():
      descriptor._encoder(write_bytes, value)

  def SerializePartialToString(self):
    """Serializes into an in-memory buffer and returns the buffer contents.

    Unlike SerializeToString(), no initialization check is performed here.
    """
    sink = StringIO()
    self._InternalSerialize(sink.write)
    return sink.getvalue()

  cls._InternalSerialize = InternalSerialize
  cls.SerializePartialToString = SerializePartialToString
814
815
816def _AddMergeFromStringMethod(message_descriptor, cls):
817  """Helper for _AddMessageMethods()."""
818  def MergeFromString(self, serialized):
819    length = len(serialized)
820    try:
821      if self._InternalParse(serialized, 0, length) != length:
822        # The only reason _InternalParse would return early is if it
823        # encountered an end-group tag.
824        raise message_mod.DecodeError('Unexpected end-group tag.')
825    except IndexError:
826      raise message_mod.DecodeError('Truncated message.')
827    except struct.error, e:
828      raise message_mod.DecodeError(e)
829    return length   # Return this for legacy reasons.
830  cls.MergeFromString = MergeFromString
831
832  local_ReadTag = decoder.ReadTag
833  local_SkipField = decoder.SkipField
834  decoders_by_tag = cls._decoders_by_tag
835
836  def InternalParse(self, buffer, pos, end):
837    self._Modified()
838    field_dict = self._fields
839    while pos != end:
840      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
841      field_decoder = decoders_by_tag.get(tag_bytes)
842      if field_decoder is None:
843        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
844        if new_pos == -1:
845          return pos
846        pos = new_pos
847      else:
848        pos = field_decoder(buffer, new_pos, end, self, field_dict)
849    return pos
850  cls._InternalParse = InternalParse
851
852
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Collected once, when the class is built, and shared by both methods.
  required_fields = []
  for candidate in message_descriptor.fields:
    if candidate.label == _FieldDescriptor.LABEL_REQUIRED:
      required_fields.append(candidate)

  def _ReportUninitialized(self, errors):
    """Appends the message's missing-field paths to errors (when a list was
    supplied) and returns False so callers can return the result directly."""
    if errors is not None:
      errors.extend(self.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().
    field_values = self._fields

    # First pass: every required field of this message must be present.
    for field in required_fields:
      if field not in field_values:
        return _ReportUninitialized(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not field_values[field]._is_present_in_parent):
        return _ReportUninitialized(self, errors)

    # Second pass: every contained message must itself be initialized.
    for field, value in field_values.iteritems():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for element in value:
          if not element.IsInitialized():
            return _ReportUninitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _ReportUninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    missing = [field.name for field in required_fields
               if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      # Extensions are shown by full name in parentheses.
      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = "%s[%d]." % (name, index)
          for sub_error in element.FindInitializationErrors():
            missing.append(prefix + sub_error)
      else:
        prefix = name + "."
        for sub_error in value.FindInitializationErrors():
          missing.append(prefix + sub_error)

    return missing

  cls.FindInitializationErrors = FindInitializationErrors
933
934
def _AddMergeFromMethod(cls):
  """Installs the MergeFrom() method on cls.

  Helper for _AddMessageMethods().

  Args:
    cls: The message class that receives the method.
  """
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the field values stored in msg into this message.

    Repeated and message-typed values are merged through their own
    MergeFrom(); a container or sub-message is created on this message
    first when none exists yet.  Every other value overwrites the value
    stored here.

    Args:
      msg: The message to merge from.  Must not be this message itself.
    """
    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.iteritems():
      is_repeated = field.label == LABEL_REPEATED
      if is_repeated or field.cpp_type == CPPTYPE_MESSAGE:
        if not is_repeated and not value._is_present_in_parent:
          # A singular sub-message can sit in msg._fields without being set
          # (it was only read).  Merging it would call _Modified() on our
          # own sub-message and wrongly mark the field as present here.
          continue
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      else:
        fields[field] = value
  cls.MergeFrom = MergeFrom
956
957
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The helpers run in a fixed order.  The two extension helpers are applied
  only when message_descriptor.is_extendable is true.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    for add_extension_method in (_AddClearExtensionMethod,
                                 _AddHasExtensionMethod):
      add_extension_method(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
976
977
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls.

  The same function is installed under two names: _Modified and
  SetInParent.
  """

  def Modified(self):
    """Marks the cached byte size as dirty and notifies the listener.

    Does nothing when the dirty bit is already set, so the listener only
    hears about the clean -> dirty transition.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return

    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  cls._Modified = cls.SetInParent = Modified
998
999
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  A child needs a way back to its parent so that code such as:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...works as expected.  This helper class provides that back reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.  May already be a weakref proxy.
    """
    # The back reference from child to parent is held weakly so that it does
    # not create cyclic garbage once the client drops the parent.  An object
    # that is already a proxy is stored as-is rather than wrapped again.
    parent_ref = parent_message
    if not isinstance(parent_ref, weakref.ProxyType):
      parent_ref = weakref.proxy(parent_ref)
    self._parent_message_weakref = parent_ref

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards the notification to the parent unless already dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client may still hold a child whose parent has already been
        # garbage-collected and set a field on it.  This is not an error.
        pass
1044
1045
1046# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1047# TODO(robinson): Unify error handling of "unknown extension" crap.
1048# TODO(robinson): Support iteritems()-style iteration over all
1049# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """

    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    If no value is stored yet, a singular scalar extension yields its
    default value without being stored.  A repeated or message-typed
    extension gets a freshly created value that is stored on the extended
    message and then returned.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Returns True iff other is an _ExtensionDict whose message lists the
    same extension fields with equal values.

    Non-extension fields of the two messages are ignored.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields.  ListFields() yields
    # (field_descriptor, value) pairs, so the descriptor has to be unpacked
    # before is_extension can be read from it.
    my_fields = [(field, value) for field, value in my_fields
                 if field.is_extension]
    other_fields = [(field, value) for field, value in other_fields
                    if field.is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    """Returns the negation of __eq__."""
    return not self == other

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: If the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle.cpp_type, extension_handle.type)
    type_checker.CheckValue(value)
    self._extended_message._fields[extension_handle] = value
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the name is not known.
    """
    return self._extended_message._extensions_by_name.get(name, None)
1149