1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# Keep it Python2.5 compatible for GAE.
32#
33# Copyright 2007 Google Inc. All Rights Reserved.
34#
35# This code is meant to work on Python 2.4 and above only.
36#
37# TODO(robinson): Helpers for verbose, common checks like seeing if a
38# descriptor's cpp_type is CPPTYPE_MESSAGE.
39
40"""Contains a metaclass and helper functions used to create
41protocol message classes from Descriptor objects at runtime.
42
43Recall that a metaclass is the "type" of a class.
44(A class is to a metaclass what an instance is to a class.)
45
46In this case, we use the GeneratedProtocolMessageType metaclass
47to inject all the useful functionality into the classes
48output by the protocol compiler at compile-time.
49
50The upshot of all this is that the real implementation
51details for ALL pure-Python protocol buffers are *here in
52this file*.
53"""
54
55__author__ = 'robinson@google.com (Will Robinson)'
56
57import sys
58if sys.version_info[0] < 3:
59  try:
60    from cStringIO import StringIO as BytesIO
61  except ImportError:
62    from StringIO import StringIO as BytesIO
63  import copy_reg as copyreg
64else:
65  from io import BytesIO
66  import copyreg
67import struct
68import weakref
69
70# We use "as" to avoid name collisions with variables.
71from google.protobuf.internal import containers
72from google.protobuf.internal import decoder
73from google.protobuf.internal import encoder
74from google.protobuf.internal import enum_type_wrapper
75from google.protobuf.internal import message_listener as message_listener_mod
76from google.protobuf.internal import type_checkers
77from google.protobuf.internal import wire_format
78from google.protobuf import descriptor as descriptor_mod
79from google.protobuf import message as message_mod
80from google.protobuf import text_format
81
# Short module-level alias for the FieldDescriptor class.  The helpers in this
# file use it to reach constants such as LABEL_REPEATED, TYPE_MESSAGE and
# CPPTYPE_MESSAGE, and for isinstance() checks on extension handles.
_FieldDescriptor = descriptor_mod.FieldDescriptor
83
84
def NewMessage(bases, descriptor, dictionary):
  """Populates the class dictionary of a message class before it is created.

  Args:
    bases: Base classes for the new class; handed back untouched.
    descriptor: Descriptor for the message type being built.
    dictionary: Class dictionary, updated in place with one attribute per
      nested extension and with the '__slots__' entry.

  Returns:
    The bases argument, unchanged.
  """
  for populate in (_AddClassAttributesForNestedExtensions, _AddSlots):
    populate(descriptor, dictionary)
  return bases
89
90
def InitMessage(descriptor, cls):
  """Installs lookup tables and all generated members on a message class.

  Args:
    descriptor: Descriptor for the message type that cls represents.
    cls: The freshly created message class to populate.
  """
  decoders_by_tag = {}
  extensions_by_number = {}
  cls._decoders_by_tag = decoders_by_tag
  cls._extensions_by_name = {}
  cls._extensions_by_number = extensions_by_number

  uses_message_set_format = (
      descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format)
  if uses_message_set_format:
    # MessageSet items get a dedicated decoder driven by the extension table.
    decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(extensions_by_number), None)

  # Attach stuff to each FieldDescriptor for quick lookup later on.
  for field_descriptor in descriptor.fields:
    _AttachFieldHelpers(cls, field_descriptor)

  for install in (_AddEnumValues,
                  _AddInitMethod,
                  _AddPropertiesForFields,
                  _AddPropertiesForExtensions):
    install(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(descriptor, cls)

  # Register a reduction function so instances can be pickled: rebuild an
  # empty message of this class, then restore its state.
  def _Reduce(obj):
    return (cls, (), obj.__getstate__())
  copyreg.pickle(cls, _Reduce)
112
113
114# Stateless helpers for GeneratedProtocolMessageType below.
115# Outside clients should not access these directly.
116#
117# I opted not to make any of these methods on the metaclass, to make it more
118# clear that I'm not really using any state there and to keep clients from
119# thinking that they have direct access to these construction helpers.
120
121
def _PropertyName(proto_field_name):
  """Maps a proto field name to the name of its public Python property.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get and (in some cases) set the field.
    Currently this is the field name itself, unchanged.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz suggested the keyword module from the stdlib, i.e. returning
  # proto_field_name + "_" whenever keyword.iskeyword(proto_field_name) holds.
  #
  # Kenton says:  That is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If we
  #   rename the properties, then every such user has to apply the same
  #   transformation.  A field named "yield" is still perfectly accessible
  #   through getattr/setattr today, and that is not even cumbersome.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
  #   position.
  property_name = proto_field_name
  return property_name
149
150
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that extension_handle is an extension of message's type.

  Args:
    message: Message instance the extension is about to be used with.
    extension_handle: Candidate extension FieldDescriptor.

  Raises:
    KeyError: If the handle is not a FieldDescriptor, is not an extension,
      has no containing type, or extends a different message type.
  """
  handle = extension_handle

  if not isinstance(handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   handle)

  if not handle.is_extension:
    raise KeyError('"%s" is not an extension.' % handle.full_name)

  extended_type = handle.containing_type
  if not extended_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % handle.full_name)

  # Identity, not equality: the handle must extend this exact descriptor.
  if extended_type is message.DESCRIPTOR:
    return
  raise KeyError('Extension "%s" extends message type "%s", but this '
                 'message is of type "%s".' %
                 (handle.full_name,
                  extended_type.full_name,
                  message.DESCRIPTOR.full_name))
171
172
def _AddSlots(message_descriptor, dictionary):
  """Stores the '__slots__' list for a message class in its class dictionary.

  The slot names are the same for every message type; message_descriptor is
  not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  )
  # A fresh list per call so classes never share one mutable object.
  dictionary['__slots__'] = list(slot_names)
190
191
def _IsMessageSetExtension(field):
  """Tells whether field is a MessageSet-style extension.

  All of the following must hold: the field is an extension, the type it
  extends has the message_set_wire_format option, the field is an optional
  message field, and its message type equals its extension scope.

  Args:
    field: FieldDescriptor to examine.

  Returns:
    A truthy value if every condition holds, otherwise the first falsy
    condition value encountered.
  """
  verdict = field.is_extension
  if verdict:
    extended_type = field.containing_type
    verdict = (extended_type.has_options and
               extended_type.GetOptions().message_set_wire_format)
  if verdict:
    verdict = (field.type == _FieldDescriptor.TYPE_MESSAGE and
               field.message_type == field.extension_scope and
               field.label == _FieldDescriptor.LABEL_OPTIONAL)
  return verdict
199
200
def _AttachFieldHelpers(cls, field_descriptor):
  """Caches codec helpers on a field and registers its decoders on cls.

  Sets _encoder, _sizer and _default_constructor on field_descriptor, then
  adds one entry per accepted wire type to cls._decoders_by_tag.

  Args:
    cls: Message class whose _decoders_by_tag table is updated.
    field_descriptor: FieldDescriptor (regular field or extension) to prepare.
  """
  field_number = field_descriptor.number
  field_type = field_descriptor.type
  repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  packed_option = (field_descriptor.has_options and
                   field_descriptor.GetOptions().packed)

  if not _IsMessageSetExtension(field_descriptor):
    codec_args = (field_number, repeated, packed_option)
    encode = type_checkers.TYPE_TO_ENCODER[field_type](*codec_args)
    compute_size = type_checkers.TYPE_TO_SIZER[field_type](*codec_args)
  else:
    encode = encoder.MessageSetItemEncoder(field_number)
    compute_size = encoder.MessageSetItemSizer(field_number)

  field_descriptor._encoder = encode
  field_descriptor._sizer = compute_size
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  # Each decoder is stored next to the field itself when the field belongs to
  # a oneof, and next to None otherwise.
  if field_descriptor.containing_oneof is not None:
    oneof_member = field_descriptor
  else:
    oneof_member = None

  def RegisterDecoder(wire_type, packed):
    tag = encoder.TagBytes(field_number, wire_type)
    field_decoder = type_checkers.TYPE_TO_DECODER[field_type](
        field_number, repeated, packed, field_descriptor,
        field_descriptor._default_constructor)
    cls._decoders_by_tag[tag] = (field_decoder, oneof_member)

  RegisterDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)

  if repeated and wire_format.IsTypePackable(field_type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
236
237
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension nested in the message as a class attribute.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is copied.
    dictionary: Class dictionary that receives one entry per extension, keyed
      by the extension's name and holding its field descriptor.
  """
  # items() works on both Python 2 and Python 3; iteritems() exists only on
  # Python 2 dicts, and this module is also imported under Python 3.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    # A nested extension must not clobber an attribute that already exists.
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field
243
244
def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums as class-level attributes.

  For every enum type, cls gets an EnumTypeWrapper under the enum's name plus
  one integer attribute per enum value, named after the value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
258
259
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function taking one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # The container is given the element's message descriptor, not its
      # concrete class, which might not have been set yet.  (Depends on the
      # order in which we initialize the classes.)  field.message_type is
      # read each time a default is built.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is only looked up when
    # the returned function is called.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
307
308
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts keyword arguments only, one per field to
  populate.  Repeated and message-typed values are copied into freshly built
  containers / sub-messages; any other value is assigned through the field's
  property via setattr().

  Args:
    message_descriptor: Descriptor for the message type; used to resolve
      keyword names to field descriptors.
    cls: The message class receiving the __init__ method.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    # items() rather than iteritems(): the latter does not exist on Python 3
    # dicts, and this module is also imported under Python 3.
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for a name that is not a field of
      # this message, so it never hands back None.
      field = _GetFieldByName(message_descriptor, field_name)
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          for element in field_value:
            container.add().MergeFrom(element)
        else:  # Scalar
          container.extend(field_value)
        self._fields[field] = container
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        submessage.MergeFrom(field_value)
        self._fields[field] = submessage
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
349
350
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message has no field called field_name.
  """
  try:
    found = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message has no "%s" field.' % field_name)
  else:
    return found
364
365
def _AddPropertiesForFields(descriptor, cls):
  """Installs one property per field, plus Extensions when extendable.

  Args:
    descriptor: Descriptor for the message type.
    cls: The message class being constructed.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  def _GetExtensions(self):
    return _ExtensionDict(self)
  cls.Extensions = property(_GetExtensions)
375
376
def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  Clients use the property to get and (for non-repeated scalar fields only)
  directly set the field's value.  A <FIELD_NAME>_FIELD_NUMBER class constant
  holding the field number is added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  # Pick the installer matching the field's shape, then run it.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
400
401
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the field's container, created on first access
  by the field's default constructor.  Assigning to the property always
  raises AttributeError.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build a new container, then atomically check whether another thread has
    # preempted us: setdefault() keeps theirs and discards ours if so.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
444
445
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Setting the value runs it through the field's type checker, records
  it in the message's _fields, and marks the message as modified; for a field
  inside a oneof the oneof state is updated as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    self._fields[field] = type_checker.CheckValue(new_value)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof is not None:
    # Members of a oneof additionally record themselves as the active member.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
490
491
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.  Reading the property
  yields the sub-message, created on first access; assigning to the property
  always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  # TODO(komarek): Can anyone explain to me why we cache the message_type this
  # way, instead of referring to field.message_type inside of getter(self)?
  # What if someone sets message_type later on (which makes for simpler
  # dynamic proto descriptor and class creation code).
  message_type = field.message_type

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage

    # Construct a new object to represent this field.
    submessage = message_type._concrete_class()  # use field.message_type?
    if field.containing_oneof is not None:
      listener = _OneofListener(self, field)
    else:
      listener = self._listener_for_children
    submessage._SetListener(listener)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, submessage)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
544
545
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each nested extension.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is walked.
    cls: The message class receiving the constants.
  """
  # items() works on both Python 2 and Python 3; iteritems() exists only on
  # Python 2 dicts, and this module is also imported under Python 3.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)
552
553
def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls.

  Args:
    cls: The message class being constructed.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same number
    # already exists.
    number = extension_handle.number
    registered = cls._extensions_by_number.setdefault(number, extension_handle)
    if registered is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, registered.full_name,
           cls.DESCRIPTOR.full_name, number))

    by_name = cls._extensions_by_name
    by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension.  Also register under type name.
      by_name[extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed
  cls.FromString = staticmethod(FromString)
586
587
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from a message's _fields.

  Returns:
    For a repeated field, whether the container is non-empty; for a message
    field, the sub-message's _is_present_in_parent flag; otherwise True.
  """
  field_descriptor = item[0]
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
598
599
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ListFields() to cls.

  ListFields() returns the (FieldDescriptor, value) pairs from the message's
  _fields that _IsPresent() accepts, sorted by field number.

  Args:
    message_descriptor: Descriptor for the message type (not used here).
    cls: The message class receiving the method.
  """

  def ListFields(self):
    # items() works on both Python 2 and Python 3; iteritems() exists only on
    # Python 2 dicts, and this module is also imported under Python 3.
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
609
610
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds HasField() to cls.

  HasField(name) accepts the name of a non-repeated field or of a oneof and
  raises ValueError for any other name.

  Args:
    message_descriptor: Descriptor supplying the fields and oneofs.
    cls: The message class receiving the method.
  """

  # Name -> descriptor for everything HasField() accepts.
  queryable = {}
  for field_descriptor in message_descriptor.fields:
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    queryable[field_descriptor.name] = field_descriptor
  # Fields inside oneofs are never repeated (enforced by the compiler).
  for oneof_descriptor in message_descriptor.oneofs:
    queryable[oneof_descriptor.name] = oneof_descriptor

  def HasField(self, field_name):
    try:
      target = queryable[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof is present exactly when its currently recorded member is.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields
    submessage = self._fields.get(target)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
642
643
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ClearField() to cls.

  ClearField(name) accepts a field name or a oneof name.  For a oneof it
  clears the currently recorded member, and returns without doing anything
  when no member is recorded.  Unknown names raise ValueError.

  Args:
    message_descriptor: Descriptor supplying fields_by_name / oneofs_by_name.
    cls: The message class receiving the method.
  """
  def ClearField(self, field_name):
    try:
      target = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message has no "%s" field.' % field_name)
      if oneof not in self._oneofs:
        return
      target = self._oneofs[oneof]

    if target in self._fields:
      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[target]

      owning_oneof = target.containing_oneof
      if self._oneofs.get(owning_oneof, None) is target:
        del self._oneofs[owning_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
674
675
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): adds ClearExtension() to cls.

  Args:
    cls: The message class receiving the method.
  """
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Drop the stored value if there is one, and always report a
    # modification, since this is a mutating method.
    self._fields.pop(extension_handle, None)
    self._Modified()
  cls.ClearExtension = ClearExtension
686
687
def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds Clear() to cls.

  Clear() drops every set field, every unknown field and all oneof
  bookkeeping, then notifies the message that it was modified.

  Args:
    message_descriptor: Descriptor for the message type (not used here).
    cls: The message class receiving the method.
  """
  def Clear(self):
    # Clear fields.
    self._fields = {}
    self._unknown_fields = ()
    # _oneofs maps each oneof to its currently set member; with no fields
    # left, any surviving entry would be stale.
    self._oneofs = {}
    self._Modified()
  cls.Clear = Clear
696
697
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): adds HasExtension() to cls.

  Args:
    cls: The message class receiving the method.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    # A message-typed extension counts only once the stored sub-message is
    # flagged as present in its parent.
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent
  cls.HasExtension = HasExtension
711
712
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds __eq__() to cls.

  Two messages compare equal when they share a descriptor, list the same
  fields with equal values, and hold the same unknown fields in any order.

  Args:
    message_descriptor: Descriptor for the message type (not used here).
    cls: The message class receiving the method.
  """
  def __eq__(self, other):
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Sort unknown fields because their order shouldn't affect equality test.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
735
736
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds __str__() to cls.

  The generated __str__ returns text_format.MessageToString() of the message.

  Args:
    message_descriptor: Descriptor for the message type (not used here).
    cls: The message class receiving the method.
  """
  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__str__ = __str__
742
743
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds __unicode__() to cls.

  The generated method renders the message with as_utf8=True and decodes the
  result from UTF-8.

  Args:
    unused_message_descriptor: Descriptor for the message type (not used).
    cls: The message class receiving the method.
  """

  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')
  cls.__unicode__ = __unicode__
750
751
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): adds _SetListener() to cls.

  Args:
    cls: The message class receiving the method.
  """
  def SetListener(self, listener):
    # None means "no listener"; a NullMessageListener stands in for it so
    # the attribute always holds a listener object.
    if listener is not None:
      self._listener = listener
    else:
      self._listener = message_listener_mod.NullMessageListener()
  cls._SetListener = SetListener
760
761
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The result of the byte-size function registered for field_type.

  Raises:
    EncodeError: If no byte-size function is registered for field_type.
  """
  # Only the table lookup is guarded, so a KeyError raised while computing the
  # size is not misreported as an unrecognized field type.
  try:
    byte_size_fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return byte_size_fn(field_number, value)
780
781
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ByteSize() to cls.

  ByteSize() returns the cached size while the cache is clean.  Otherwise it
  sums the sizes of all listed fields and of the raw unknown fields, stores
  the total, and clears the dirty flags.

  Args:
    message_descriptor: Descriptor for the message type (not used here).
    cls: The message class receiving the method.
  """

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    known_size = sum(field_descriptor._sizer(field_value)
                     for field_descriptor, field_value in self.ListFields())
    unknown_size = sum(len(tag_bytes) + len(value_bytes)
                       for tag_bytes, value_bytes in self._unknown_fields)
    total = known_size + unknown_size

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
802
803
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializeToString() on cls.  message_descriptor is not used.
  """

  def SerializeToString(self):
    """Serializes the message after checking that it is fully initialized.

    Returns:
      Whatever self.SerializePartialToString() returns.

    Raises:
      message_mod.EncodeError: if self.IsInitialized() is false.  The error
        text names the message type and lists the comma-joined paths from
        FindInitializationErrors().
    """
    # Check if the message has all of its required fields set.
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
              self.DESCRIPTOR.full_name,
              ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
816
817
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() together with the _InternalSerialize()
  worker it delegates to.  message_descriptor is not used.
  """

  def InternalSerialize(self, write_bytes):
    """Hands the encoded message to write_bytes one piece at a time."""
    # Set fields go first, each through the encoder kept on its descriptor.
    for descriptor, value in self.ListFields():
      descriptor._encoder(write_bytes, value)
    # Unknown fields follow, replayed from the stored (tag, value) pairs.
    for unknown in self._unknown_fields:
      tag, payload = unknown
      write_bytes(tag)
      write_bytes(payload)

  def SerializePartialToString(self):
    """Returns the serialized form; no initialization check is made here."""
    sink = BytesIO()
    self._InternalSerialize(sink.write)
    return sink.getvalue()

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
834
835
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() loop it drives.
  message_descriptor is not used; decoders are found through
  cls._decoders_by_tag, which is read when this helper runs and so must
  already be set on cls.
  """
  def MergeFromString(self, serialized):
    """Parses serialized and merges what it finds into this message.

    Args:
      serialized: The buffer to parse.  It must support len() and is passed
        to _InternalParse() unchanged.

    Returns:
      len(serialized), for legacy reasons.

    Raises:
      message_mod.DecodeError: if parsing stops before the end of the buffer,
        or if an IndexError, TypeError or struct.error escapes the parse.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error, e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound once so the parse loop reads closure variables instead of repeating
  # these attribute lookups for every tag.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into this message.

    Args:
      buffer: The serialized data.
      pos: Offset at which to start parsing.
      end: Offset at which to stop parsing.

    Returns:
      The offset at which parsing stopped.  This equals end unless
      local_SkipField() returned -1, in which case it is the offset of the
      tag that could not be skipped.
    """
    # The message is marked modified up front, even if nothing gets parsed.
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # field_desc is only acted on when truthy; presumably it is filled in
      # just for oneof members -- confirm where _decoders_by_tag is built.
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # No decoder for this tag: keep the raw bytes as an unknown field.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          # Return the position of the tag itself, not of its value.
          return pos
        if not unknown_field_list:
          # Replace the empty/falsy container with a fresh list owned by
          # this message before appending to it.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
879
880
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Collected once per class; both methods below close over this list.
  required_fields = []
  for candidate in message_descriptor.fields:
    if candidate.label == _FieldDescriptor.LABEL_REQUIRED:
      required_fields.append(candidate)

  def _ReportUninitialized(self, errors):
    """Fills errors, when one was supplied, and returns False."""
    if errors is not None:
      errors.extend(self.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().
    present = self._fields

    for field in required_fields:
      if field not in present:
        return _ReportUninitialized(self, errors)
      # A required sub-message that exists in the dict but was never marked
      # present in its parent still counts as missing.
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not present[field]._is_present_in_parent):
        return _ReportUninitialized(self, errors)

    # Iterate over a snapshot: the dict can change size!
    for field, value in list(self._fields.items()):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for element in value:
          if not element.IsInitialized():
            return _ReportUninitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _ReportUninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    # Missing required fields of this message come first.
    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    # Then recurse into every set message-typed field, prefixing each nested
    # path with the path of the field that contains it.
    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = "%s[%d]." % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + "."
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
961
962
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of msg into this message.

    Repeated fields are merged through the container's own MergeFrom(),
    singular message fields are merged recursively when the source marks
    them present, and singular scalar fields are overwritten.  Unknown
    fields of msg are appended to this message's unknown fields.

    Args:
      msg: The message to merge from.  Must be an instance of cls and must
        not be self.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.iteritems():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        fields[field] = value
        # A scalar written straight into the dict bypasses the oneof
        # bookkeeping, so record it as the active member here (this also
        # drops a previously active sibling), as the parse path does.
        if field.containing_oneof is not None:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1003
1004
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls, resolving names via message_descriptor."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None.

    Raises:
      ValueError: if message_descriptor has no oneof called oneof_name.
    """
    known_oneofs = message_descriptor.oneofs_by_name
    try:
      oneof = known_oneofs[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # A recorded member only counts while the message still has it set.
    active = self._oneofs.get(oneof)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1021
1022
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The individual helpers are applied in a fixed order.  Most take
  (message_descriptor, cls); the few that only need cls are called
  separately.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors are only installed on extendable messages.
  if message_descriptor.is_extendable:
    for add_method in (_AddClearExtensionMethod,
                       _AddHasExtensionMethod):
      add_method(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod,
                     _AddUnicodeMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
1043
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified() (also exposed as SetInParent()) and
  _UpdateOneofState().  message_descriptor is not used.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    # setdefault() records field when the oneof had no active member yet and
    # otherwise hands back whichever member was active.
    active = self._oneofs.setdefault(oneof, field)
    if active is not field:
      del self._fields[active]
      self._oneofs[oneof] = field

  cls._UpdateOneofState = _UpdateOneofState
  cls._Modified = Modified
  cls.SetInParent = Modified
1076
1077
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  Child objects need a back reference to their parent so that code such as

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  works: setting a field deep in the tree has to be reported upwards.  This
  helper class carries that back reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from the child (contained) object to its parent
    # (containing) object is kept weak, so that no cyclic garbage is created
    # when the client is done with the parent.  A parent that already arrives
    # as a weak proxy is stored as is.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # As an optimization, the listener itself also records whether the parent
    # message is dirty, which avoids walking up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Notifies the parent message, unless it is already known to be dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client may keep a child object alive and set a field on it after
        # the child's parent has been garbage-collected.  Not an error.
        pass
1122
1123
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the active oneof member first, then pass the
      # notification on through the base listener.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # Raised when the weakly referenced parent no longer exists.
      pass
1143
1144
1145# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1146# TODO(robinson): Unify error handling of "unknown extension" crap.
1147# TODO(robinson): Support iteritems()-style iteration over all
1148# extensions with the "has" bits turned on?
1149class _ExtensionDict(object):
1150
1151  """Dict-like container for supporting an indexable "Extensions"
1152  field on proto instances.
1153
1154  Note that in all cases we expect extension handles to be
1155  FieldDescriptors.
1156  """
1157
1158  def __init__(self, extended_message):
1159    """extended_message: Message instance for which we are the Extensions dict.
1160    """
1161
1162    self._extended_message = extended_message
1163
1164  def __getitem__(self, extension_handle):
1165    """Returns the current value of the given extension handle."""
1166
1167    _VerifyExtensionHandle(self._extended_message, extension_handle)
1168
1169    result = self._extended_message._fields.get(extension_handle)
1170    if result is not None:
1171      return result
1172
1173    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1174      result = extension_handle._default_constructor(self._extended_message)
1175    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1176      result = extension_handle.message_type._concrete_class()
1177      try:
1178        result._SetListener(self._extended_message._listener_for_children)
1179      except ReferenceError:
1180        pass
1181    else:
1182      # Singular scalar -- just return the default without inserting into the
1183      # dict.
1184      return extension_handle.default_value
1185
1186    # Atomically check if another thread has preempted us and, if not, swap
1187    # in the new object we just created.  If someone has preempted us, we
1188    # take that object and discard ours.
1189    # WARNING:  We are relying on setdefault() being atomic.  This is true
1190    #   in CPython but we haven't investigated others.  This warning appears
1191    #   in several other locations in this file.
1192    result = self._extended_message._fields.setdefault(
1193        extension_handle, result)
1194
1195    return result
1196
1197  def __eq__(self, other):
1198    if not isinstance(other, self.__class__):
1199      return False
1200
1201    my_fields = self._extended_message.ListFields()
1202    other_fields = other._extended_message.ListFields()
1203
1204    # Get rid of non-extension fields.
1205    my_fields    = [ field for field in my_fields    if field.is_extension ]
1206    other_fields = [ field for field in other_fields if field.is_extension ]
1207
1208    return my_fields == other_fields
1209
1210  def __ne__(self, other):
1211    return not self == other
1212
1213  def __hash__(self):
1214    raise TypeError('unhashable object')
1215
  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
1220  def __setitem__(self, extension_handle, value):
1221    """If extension_handle specifies a non-repeated, scalar extension
1222    field, sets the value of that field.
1223    """
1224
1225    _VerifyExtensionHandle(self._extended_message, extension_handle)
1226
1227    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1228        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1229      raise TypeError(
1230          'Cannot assign to extension "%s" because it is a repeated or '
1231          'composite type.' % extension_handle.full_name)
1232
1233    # It's slightly wasteful to lookup the type checker each time,
1234    # but we expect this to be a vanishingly uncommon case anyway.
1235    type_checker = type_checkers.GetTypeChecker(
1236        extension_handle)
1237    # pylint: disable=protected-access
1238    self._extended_message._fields[extension_handle] = (
1239      type_checker.CheckValue(value))
1240    self._extended_message._Modified()
1241
1242  def _FindExtensionByName(self, name):
1243    """Tries to find a known extension with the specified name.
1244
1245    Args:
1246      name: Extension full name.
1247
1248    Returns:
1249      Extension field descriptor.
1250    """
1251    return self._extended_message._extensions_by_name.get(name, None)
1252