1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Contains well known classes.
32
This file defines well-known classes that need extra maintenance, including:
34  - Any
35  - Duration
36  - FieldMask
37  - Struct
38  - Timestamp
39"""
40
41__author__ = 'jieluo@google.com (Jie Luo)'
42
43from datetime import datetime
44from datetime import timedelta
45import six
46
47from google.protobuf.descriptor import FieldDescriptor
48
# strptime() pattern for the whole-second part of an RFC 3339 timestamp
# (everything before the optional fraction and the timezone designator).
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'

# Unit conversion factors, built up from the smaller units.
_NANOS_PER_MICROSECOND = 1000
_NANOS_PER_MILLISECOND = 1000 * _NANOS_PER_MICROSECOND
_NANOS_PER_SECOND = 1000 * _NANOS_PER_MILLISECOND
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000 * _MILLIS_PER_SECOND
_SECONDS_PER_DAY = 24 * 60 * 60
56
57
class Error(Exception):
  """Top-level module error.

  Base class for the exceptions defined in this module; ParseError derives
  from it.
  """
60
61
class ParseError(Error):
  """Raised when a well known type cannot be parsed from its string form.

  Within this module it is raised by Timestamp.FromJsonString and
  Duration.FromJsonString.
  """
64
65
class Any(object):
  """Class for Any Message type.

  The methods read and write the `type_url` and `value` attributes of the
  instance they are called on; this class does not define them itself.
  """

  def Pack(self, msg, type_url_prefix='type.googleapis.com/'):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack. Its DESCRIPTOR.full_name becomes the last
          segment of type_url and its serialized form becomes value.
      type_url_prefix: Prefix placed in front of the type name. A '/' is
          inserted when the prefix does not already end with one, which
          includes the empty prefix.
    """
    full_name = msg.DESCRIPTOR.full_name
    if type_url_prefix.endswith('/'):
      self.type_url = '%s%s' % (type_url_prefix, full_name)
    else:
      self.type_url = '%s/%s' % (type_url_prefix, full_name)
    self.value = msg.SerializeToString()

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Args:
      msg: The message to parse the packed payload into.

    Returns:
      True if this Any holds the type of msg and msg was parsed from value,
      False (leaving msg untouched) otherwise.
    """
    if not self.Is(msg.DESCRIPTOR):
      return False
    msg.ParseFromString(self.value)
    return True

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only last part is to be used: b/25630112
    return self.type_url.split('/')[-1]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type.

    Args:
      descriptor: Descriptor whose full_name is compared with the type name
          stored in type_url.

    Returns:
      True if type_url is well formed (contains a '/') and its last segment
      equals descriptor.full_name.
    """
    # A type URL without any '/' is malformed, so it cannot identify a type
    # even if the whole string happens to equal the full name.
    return '/' in self.type_url and self.TypeName() == descriptor.full_name
93
94
class Timestamp(object):
  """Class for Timestamp message type.

  The methods read and write the `seconds` and `nanos` attributes of the
  instance they are called on; this class does not define them itself.
  """

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    # Normalize so that 0 <= nanos < 1s, carrying any excess into seconds.
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = datetime(1970, 1, 1) + timedelta(days, seconds)

    result = dt.isoformat()
    # Integer arithmetic only: the fraction is printed exactly as stored.
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
          accepted as long as they fit into nano-seconds precision.
          Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ParseError: On a missing or malformed timezone designator, on trailing
          data after 'Z', or on a fraction that is not 1-9 decimal digits.
      ValueError: From strptime()/int() when the date-time part or the
          offset digits are malformed.
    """
    # The timezone designator is 'Z', or starts at the first '+' or at the
    # last '-' (earlier '-' characters belong to the date).
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      timezone_offset = value.rfind('-')
    if timezone_offset == -1:
      raise ParseError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ParseError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    if nano_value:
      # Reject things like '1e5' that float() would happily accept and
      # turn into an out-of-range nanos value.
      if not nano_value.isdigit():
        raise ParseError(
            'Failed to parse Timestamp: invalid fractional '
            'seconds {0}.'.format(nano_value))
      # Right-pad to 9 digits: '021' -> 021000000 nanoseconds, exactly.
      nanos = int(nano_value.ljust(9, '0'))
    else:
      nanos = 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ParseError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ParseError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      offset_seconds = (
          int(timezone[1:pos]) * 60 + int(timezone[pos + 1:])) * 60
      # Local time ahead of UTC ('+') means UTC is earlier, and vice versa.
      if timezone[0] == '+':
        seconds -= offset_seconds
      else:
        seconds += offset_seconds
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    self.FromDatetime(datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to datetime.

    Returns:
      A naive datetime expressing the timestamp in UTC. Nanoseconds below
      microsecond resolution are dropped, matching ToMicroseconds().
    """
    # Pure integer arithmetic from the epoch: no float rounding (which
    # could bump .999999999 up to the next second) and no dependence on
    # the platform's supported range for utcfromtimestamp().
    return datetime(1970, 1, 1) + timedelta(
        seconds=self.seconds,
        microseconds=self.nanos // _NANOS_PER_MICROSECOND)

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    Args:
      dt: A datetime. A naive value is taken to be UTC; a timezone-aware
          value is converted to UTC first.
    """
    offset = dt.utcoffset()
    if offset is not None:
      # Aware datetimes cannot be subtracted from the naive epoch below,
      # so shift to UTC and drop the tzinfo.
      dt = dt.replace(tzinfo=None) - offset
    td = dt - datetime(1970, 1, 1)
    self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
    self.nanos = td.microseconds * _NANOS_PER_MICROSECOND
236
237
class Duration(object):
  """Class for Duration message type.

  The methods read and write the `seconds` and `nanos` attributes of the
  instance they are called on; this class does not define them itself.
  """

  def ToJsonString(self):
    """Converts Duration to string format.

    Returns:
      A string converted from self. The string format will contain
      3, 6, or 9 fractional digits depending on the precision required to
      represent the exact Duration value. For example: "1s", "1.010s",
      "1.000000100s", "-3.100s"
    """
    # Work on the magnitude with integer arithmetic and emit the sign
    # separately, so that e.g. (0, -500000000) prints as "-0.500s".
    if self.seconds < 0 or self.nanos < 0:
      result = '-'
      seconds = -self.seconds + (-self.nanos) // _NANOS_PER_SECOND
      nanos = (-self.nanos) % _NANOS_PER_SECOND
    else:
      result = ''
      seconds = self.seconds + self.nanos // _NANOS_PER_SECOND
      nanos = self.nanos % _NANOS_PER_SECOND
    result += '%d' % seconds
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 's'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03ds' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06ds' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09ds' % nanos

  def FromJsonString(self, value):
    """Converts a string to Duration.

    The message is only modified when the whole string parses; on failure
    seconds and nanos keep their previous values.

    Args:
      value: A string to be converted. The string must end with 's'. Any
          fractional digits (or none) are accepted as long as they fit into
          precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"

    Raises:
      ParseError: On parsing problems.
    """
    if len(value) < 1 or value[-1] != 's':
      raise ParseError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      pos = value.find('.')
      if pos == -1:
        seconds = int(value[:-1])
        nanos = 0
      else:
        seconds = int(value[:pos])
        # int('-0') loses the sign, so the fraction carries it explicitly.
        if value[0] == '-':
          nanos = int(round(float('-0{0}'.format(value[pos: -1])) * 1e9))
        else:
          nanos = int(round(float('0{0}'.format(value[pos: -1])) * 1e9))
    except ValueError:
      raise ParseError(
          'Couldn\'t parse duration: {0}.'.format(value))
    # Rounding a fraction such as .9999999999 yields a full second; carry
    # it so that nanos stays strictly within one second.
    if nanos >= _NANOS_PER_SECOND:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    elif nanos <= -_NANOS_PER_SECOND:
      seconds -= 1
      nanos += _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts a Duration to microseconds, truncating toward zero."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return self.seconds * _MICROS_PER_SECOND + micros

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds, truncating toward zero."""
    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
    return self.seconds * _MILLIS_PER_SECOND + millis

  def ToSeconds(self):
    """Converts a Duration to seconds (the nanos part is ignored)."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
                            nanos % _NANOS_PER_SECOND)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    self._NormalizeDuration(
        micros // _MICROS_PER_SECOND,
        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    self._NormalizeDuration(
        millis // _MILLIS_PER_SECOND,
        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self):
    """Converts Duration to timedelta.

    Nanoseconds below microsecond resolution are truncated toward zero.
    """
    return timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
                            td.microseconds * _NANOS_PER_MICROSECOND)

  def _NormalizeDuration(self, seconds, nanos):
    """Sets Duration from seconds and nanos, making their signs agree.

    Args:
      seconds: Whole seconds, as produced by floor division.
      nanos: Non-negative remainder in nanoseconds.
    """
    # Force nanos to be negative if the duration is negative.
    if seconds < 0 and nanos > 0:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos
360
361
def _RoundTowardZero(value, divider):
  """Divides value by divider, truncating the quotient toward zero.

  Args:
    value: Integer dividend.
    divider: Non-zero integer divisor (either sign).

  Returns:
    The integer quotient with any fractional part discarded, e.g.
    _RoundTowardZero(-5, 2) == -2.
  """
  # Python's // floors (rounds toward negative infinity), so (-5) // 2 is
  # -3 with remainder 1. Whenever the floored quotient is negative and the
  # division was inexact, the truncated quotient is one larger. The
  # remainder takes the sign of the divider, hence the != 0 test.
  quotient, remainder = divmod(value, divider)
  if quotient < 0 and remainder != 0:
    return quotient + 1
  return quotient
375
376
class FieldMask(object):
  """Class for FieldMask message type.

  The methods operate on the `paths` attribute and the Clear() method of
  the instance they are called on; this class does not define them itself.
  """

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec."""
    return ','.join(self.paths)

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec.

    Args:
      value: Comma separated field paths. An empty string produces an
          empty mask, mirroring what ToJsonString() emits for one.
    """
    self.Clear()
    # ''.split(',') is [''], which would add a bogus empty path.
    if value:
      for path in value.split(','):
        self.paths.append(path)

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor.

    Returns:
      True if every path in this mask resolves against message_descriptor
      (trivially True for an empty mask), False otherwise.
    """
    for path in self.paths:
      if not _IsValidPath(message_descriptor, path):
        return False
    return True

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    for field in message_descriptor.fields:
      self.paths.append(field.name)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    tree = _FieldMaskTree(mask)
    tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    tree.MergeFromFieldMask(mask2)
    tree.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    intersection = _FieldMaskTree()
    for path in mask2.paths:
      tree.IntersectPath(path, intersection)
    intersection.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
          field if False.
      replace_repeated_field: Replace repeated field if True. Append
          elements of repeated field if False.
    """
    tree = _FieldMaskTree(self)
    tree.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
450
451
def _IsValidPath(message_descriptor, path):
  """Checks whether the path is valid for Message Descriptor.

  Args:
    message_descriptor: Descriptor of the message the path starts from.
    path: Dot separated field path, e.g. "foo.bar.baz".

  Returns:
    True if every component but the last names a singular message field and
    the last component names a field of the message reached that way;
    False otherwise, including when a component does not exist.
  """
  parts = path.split('.')
  last = parts.pop()
  for name in parts:
    # .get() rather than [] so that an unknown name yields None (and thus
    # False below) instead of raising KeyError.
    field = message_descriptor.fields_by_name.get(name)
    if (field is None or
        field.label == FieldDescriptor.LABEL_REPEATED or
        field.type != FieldDescriptor.TYPE_MESSAGE):
      return False
    message_descriptor = field.message_type
  return last in message_descriptor.fields_by_name
464
465
def _CheckFieldMaskMessage(message):
  """Raises ValueError if message is not a FieldMask.

  A message qualifies when its descriptor is named 'FieldMask' and that
  descriptor's file is 'google/protobuf/field_mask.proto'.
  """
  descriptor = message.DESCRIPTOR
  is_field_mask = (
      descriptor.name == 'FieldMask' and
      descriptor.file.name == 'google/protobuf/field_mask.proto')
  if is_field_mask:
    return
  raise ValueError(
      'Message {0} is not a FieldMask.'.format(descriptor.full_name))
473
474
class _FieldMaskTree(object):
  """Represents a FieldMask in a tree structure.

  The tree is stored as nested dicts keyed by path component; an empty
  dict is a leaf. For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  the FieldMaskTree will be:
      [_root] -+- foo -+- bar
               |       |
               |       +- baz
               |
               +- bar --- baz
  In the tree, each leaf node represents a field path.
  """

  def __init__(self, field_mask=None):
    """Creates an empty tree, then merges in field_mask when one is given."""
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Adds every path of field_mask to the tree."""
    for field_path in field_mask.paths:
      self.AddPath(field_path)

  def AddPath(self, path):
    """Adds a field path into the tree.

    Three cases arise while walking the components of the path:
      - A leaf is reached before the path ends: the tree already covers
        the path, so nothing changes.
      - The path ends on an existing non-leaf node: that node becomes a
        leaf and loses its children, because the path covers all of them.
      - Otherwise the missing nodes are created.

    Args:
      path: The field path to add.
    """
    current = self._root
    for segment in path.split('.'):
      subtree = current.get(segment)
      if subtree is None:
        # Unknown component: grow the tree.
        subtree = {}
        current[segment] = subtree
      elif not subtree:
        # Pre-existing empty node implies we already have this entire tree.
        return
      current = subtree
    # Remove any sub-trees we might have had.
    current.clear()

  def ToFieldMask(self, field_mask):
    """Replaces the contents of field_mask with the paths of this tree."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Calculates the intersection part of a field path with this tree.

    Args:
      path: The field path to intersect with this tree.
      intersection: The out tree to record the intersection part.
    """
    current = self._root
    for segment in path.split('.'):
      if segment not in current:
        # The path leaves the tree: no overlap.
        return
      subtree = current[segment]
      if not subtree:
        # A leaf of this tree covers the path, so the whole path overlaps.
        intersection.AddPath(path)
        return
      current = subtree
    # The path ended on an inner node: everything below it overlaps.
    intersection.AddLeafNodes(path, current)

  def AddLeafNodes(self, prefix, node):
    """Adds the leaf paths found under node, each prefixed with prefix."""
    if not node:
      self.AddPath(prefix)
      return
    for name, subtree in node.items():
      self.AddLeafNodes(prefix + '.' + name, subtree)

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merge all fields specified by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)
560
561
def _StrConvert(value):
  """Returns value as a str, UTF-8 encoding it when it is not one already."""
  # This file is imported by the C extension, and some methods such as
  # ClearField require a str field name. Python 2 and Python 3 have
  # different text types, so the name may arrive as unicode.
  return value if isinstance(value, str) else value.encode('utf-8')
570
571
def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merge all fields specified by a sub-tree from source to destination.

  Args:
    node: Dict mapping field names to sub-trees; an empty sub-tree selects
        the whole field.
    source: Message to read the fields from.
    destination: Message to write the fields to.
    replace_message: If True, a selected singular message field is cleared
        in destination before merging; if False it is merged.
    replace_repeated: If True, a selected repeated field is cleared in
        destination before the source elements are appended.

  Raises:
    ValueError: If a name in node is not a field of source, or a sub-tree
        is given for a field that is not a singular message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name in node:
    child = node[name]
    # .get() rather than [] so that an unknown name reaches the ValueError
    # below instead of escaping as a KeyError.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      _MergeMessage(
          child, getattr(source, name), getattr(destination, name),
          replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        # Message elements are added and merged one at a time.
        for item in repeated_source:
          repeated_destination.add().MergeFrom(item)
      else:
        repeated_destination.extend(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        # Skip unset source messages so the destination field is not
        # touched by the merge.
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        setattr(destination, name, getattr(source, name))
611
612
def _AddFieldPaths(node, prefix, field_mask):
  """Adds the field paths descended from node to field_mask.

  Paths are appended in sorted order at every level of the tree.

  Args:
    node: Dict mapping field names to sub-trees; an empty dict is a leaf.
    prefix: Dot separated path leading to node, '' for the root.
    field_mask: Object whose `paths` receives the resulting paths.
  """
  # A leaf contributes its own path. The root of an empty tree is also an
  # empty dict, but it has no path and must not contribute ''.
  if not node and prefix:
    field_mask.paths.append(prefix)
    return
  for name in sorted(node):
    if prefix:
      child_path = prefix + '.' + name
    else:
      child_path = name
    _AddFieldPaths(node[name], child_path, field_mask)
624
625
# Types that _SetStructValue() stores as number_value: the integer type(s)
# reported by six for the running Python version, plus float.
_INT_OR_FLOAT = six.integer_types + (float,)
627
628
def _SetStructValue(struct_value, value):
  """Stores a Python value in the matching field of struct_value.

  Args:
    struct_value: Object receiving one of null_value, bool_value,
        string_value or number_value.
    value: None, a bool, a string, an int or a float.

  Raises:
    ValueError: If value is of any other type.
  """
  if value is None:
    struct_value.null_value = 0
    return
  # Note: bool must be tested before the numeric types because in Python
  # True and False are also considered numbers.
  if isinstance(value, bool):
    struct_value.bool_value = value
    return
  if isinstance(value, six.string_types):
    struct_value.string_value = value
    return
  if isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
    return
  raise ValueError('Unexpected type')
642
643
def _GetStructValue(struct_value):
  """Returns the Python-side value held by struct_value.

  Returns:
    None for a null value; otherwise the content of whichever member of
    the 'kind' oneof is set (struct, number, string, bool or list).

  Raises:
    ValueError: If no member of the 'kind' oneof is set.
  """
  kind = struct_value.WhichOneof('kind')
  if kind is None:
    raise ValueError('Value not set')
  if kind == 'null_value':
    return None
  # Every remaining known kind is returned as stored in the field of the
  # same name.
  passthrough_kinds = ('struct_value', 'number_value', 'string_value',
                       'bool_value', 'list_value')
  if kind in passthrough_kinds:
    return getattr(struct_value, kind)
  return None
660
661
class Struct(object):
  """Class for Struct message type.

  Gives dict-style access to the entries of the `fields` attribute of the
  instance the methods are called on.
  """

  __slots__ = []

  def __getitem__(self, key):
    """Returns the Python-side value stored under key."""
    entry = self.fields[key]
    return _GetStructValue(entry)

  def __setitem__(self, key, value):
    """Stores a None, bool, string or number under key."""
    entry = self.fields[key]
    _SetStructValue(entry, value)

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    return entry.list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    return entry.struct_value

  # TODO(haberman): allow constructing/merging from dict.
682
683
class ListValue(object):
  """Class for ListValue message type.

  Gives list-style access to the `values` attribute of the instance the
  methods are called on.
  """

  def __len__(self):
    """Returns the number of stored values."""
    return len(self.values)

  def append(self, value):
    """Appends a None, bool, string or number to the list."""
    new_entry = self.values.add()
    _SetStructValue(new_entry, value)

  def extend(self, elem_seq):
    """Appends every element of elem_seq, in order."""
    for element in elem_seq:
      self.append(element)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    entry = self.values.__getitem__(index)
    return _GetStructValue(entry)

  def __setitem__(self, index, value):
    """Overwrites the item at the specified index."""
    entry = self.values.__getitem__(index)
    _SetStructValue(entry, value)

  def items(self):
    """Yields the stored values in order, converted to Python values."""
    count = len(self)
    for position in range(count):
      yield self[position]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    new_entry = self.values.add()
    return new_entry.struct_value

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    new_entry = self.values.add()
    return new_entry.list_value
715
716
# Maps the fully qualified name of each well known type to the helper class
# defined for it in this module.
# NOTE(review): presumably used as extra base classes when the generated
# message classes are built; confirm against the code that reads it.
WKTBASES = {
    'google.protobuf.Any': Any,
    'google.protobuf.Duration': Duration,
    'google.protobuf.FieldMask': FieldMask,
    'google.protobuf.ListValue': ListValue,
    'google.protobuf.Struct': Struct,
    'google.protobuf.Timestamp': Timestamp,
}
725