1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31#PY25 compatible for GAE.
32#
33# Copyright 2009 Google Inc. All Rights Reserved.
34
35"""Code for decoding protocol buffer primitives.
36
37This code is very similar to encoder.py -- read the docs for that module first.
38
39A "decoder" is a function with the signature:
40  Decode(buffer, pos, end, message, field_dict)
41The arguments are:
42  buffer:     The string containing the encoded message.
43  pos:        The current position in the string.
44  end:        The position in the string where the current message ends.  May be
45              less than len(buffer) if we're reading a sub-message.
46  message:    The message object into which we're parsing.
47  field_dict: message._fields (avoids a hashtable lookup).
48The decoder reads the field and stores it into field_dict, returning the new
49buffer position.  A decoder for a repeated field may proactively decode all of
50the elements of that field, if they appear consecutively.
51
52Note that decoders may throw any of the following:
53  IndexError:  Indicates a truncated message.
54  struct.error:  Unpacking of a fixed-width field failed.
55  message.DecodeError:  Other errors.
56
57Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
59"end" as long as you are sure that someone else will notice and throw an
60exception later on.
61
62Something up the call stack is expected to catch IndexError and struct.error
63and convert them to message.DecodeError.
64
65Decoders are constructed using decoder constructors with the signature:
66  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
67The arguments are:
68  field_number:  The field number of the field we want to decode.
69  is_repeated:   Is the field a repeated field? (bool)
70  is_packed:     Is the field a packed field? (bool)
71  key:           The key to use when looking up the field within field_dict.
72                 (This is actually the FieldDescriptor but nothing in this
73                 file should depend on that.)
74  new_default:   A function which takes a message object as a parameter and
75                 returns a new instance of the default value for this field.
76                 (This is called for repeated fields and sub-messages, when an
77                 instance does not already exist.)
78
79As with encoders, we define a decoder constructor for every type of field.
80Then, for every field of every message class we construct an actual decoder.
81That decoder goes into a dict indexed by tag, so when we decode a message
82we repeatedly read a tag, look up the corresponding decoder, and invoke it.
83"""
84
85__author__ = 'kenton@google.com (Kenton Varda)'
86
87import struct
88import sys  ##PY25
89_PY2 = sys.version_info[0] < 3  ##PY25
90from google.protobuf.internal import encoder
91from google.protobuf.internal import wire_format
92from google.protobuf import message
93
94
95# This will overflow and thus become IEEE-754 "infinity".  We would use
96# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
97_POS_INF = 1e10000
98_NEG_INF = -_POS_INF
99_NAN = _POS_INF * 0
100
101
# Module-level alias for message.DecodeError.  This is not for optimization,
# but rather to avoid conflicts with local variables named "message" (the
# decoders in this file all take a parameter with that name).
_DecodeError = message.DecodeError
105
106
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  The decoded number is bitwise-anded with |mask| (e.g. to limit it to 32
  bits) and then converted with |result_type| before being returned.

  The returned function has the signature DecodeVarint(buffer, pos) and
  returns a (value, new_pos) pair.  It deliberately does not take the usual
  "end" parameter -- the caller is expected to do bounds checking after the
  fact (often the caller can defer such checking until later).

  Raises (from the returned decoder):
    _DecodeError:  The tenth byte of the varint still has its continuation
      bit set.
    IndexError:  Indexing ran off the end of the buffer.
  """

  to_int = ord
  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while 1:
      # Indexing a Python 2 str yields a one-character string; otherwise the
      # element is used directly as an integer.
      if py2:
        byte = to_int(buffer[pos])
      else:
        byte = buffer[pos]
      pos += 1
      # The low seven bits are payload, least-significant group first.
      accumulated |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        # Continuation bit set: another byte follows.
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      return (result_type(accumulated & mask), pos)
  return DecodeVarint
135
136
def _SignedVarintDecoder(mask, result_type):
  """Like _VarintDecoder() but decodes signed (two's-complement) values.

  Args:
    mask:  Bit mask selecting the width of the value.  It must have the form
      (1 << bits) - 1; the highest bit it covers is treated as the sign bit.
    result_type:  Callable applied to the final number (e.g. int or long).

  Returns:
    A function DecodeVarint(buffer, pos) returning a (value, new_pos) pair.
    Like the unsigned variant it takes no "end" parameter; the caller is
    expected to do bounds checking after the fact.

  The decoded number is first truncated to the mask width and then
  sign-extended from the mask's top bit.  This gives the same result for a
  negative 32-bit value whether the writer sign-extended it to ten bytes or
  emitted only its low 32 bits, and it guarantees that the result always
  fits in the requested width.

  Raises (from the returned decoder):
    _DecodeError:  The tenth byte of the varint still has its continuation
      bit set.
    IndexError:  Indexing ran off the end of the buffer.
  """

  # For mask == (1 << bits) - 1 this is 1 << (bits - 1).
  signbit = (mask + 1) >> 1
  local_ord = ord
  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = local_ord(buffer[pos]) if py2 else buffer[pos]
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # Drop any bits above the field width, then reinterpret the top
        # remaining bit as the sign:  flipping it and subtracting its weight
        # maps [signbit, mask] onto [-signbit, -1] and leaves smaller values
        # unchanged.
        result &= mask
        result = (result ^ signbit) - signbit
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint
162
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# Full-width decoders:  the mask keeps the low 64 bits.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, long)
_DecodeSignedVarint = _SignedVarintDecoder(0xFFFFFFFFFFFFFFFF, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(0xFFFFFFFF, int)
173
174
def ReadTag(buffer, pos):
  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as the raw slice of the buffer instead of being decoded
  into a number, so that it can be used directly as a key when looking up the
  proper decoder.  This trades work that would be done in pure Python
  (decoding a varint) for work that is done in C (looking up a byte string in
  a hash table).  In a low-level language decoding the varint would be much
  cheaper, but not in Python.

  No bounds checking is done here; indexing past the end of the buffer raises
  IndexError.
  """

  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  tag_start = pos
  while 1:
    byte = buffer[pos]
    if py2:
      byte = ord(byte)
    pos += 1
    # A clear high bit marks the last byte of the tag.
    if not (byte & 0x80):
      return (buffer[tag_start:pos], pos)
193
194
195# --------------------------------------------------------------------
196
197
198def _SimpleDecoder(wire_type, decode_value):
199  """Return a constructor for a decoder for fields of a particular type.
200
201  Args:
202      wire_type:  The field's wire type.
203      decode_value:  A function which decodes an individual value, e.g.
204        _DecodeVarint()
205  """
206
207  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
208    if is_packed:
209      local_DecodeVarint = _DecodeVarint
210      def DecodePackedField(buffer, pos, end, message, field_dict):
211        value = field_dict.get(key)
212        if value is None:
213          value = field_dict.setdefault(key, new_default(message))
214        (endpoint, pos) = local_DecodeVarint(buffer, pos)
215        endpoint += pos
216        if endpoint > end:
217          raise _DecodeError('Truncated message.')
218        while pos < endpoint:
219          (element, pos) = decode_value(buffer, pos)
220          value.append(element)
221        if pos > endpoint:
222          del value[-1]   # Discard corrupt value.
223          raise _DecodeError('Packed element was truncated.')
224        return pos
225      return DecodePackedField
226    elif is_repeated:
227      tag_bytes = encoder.TagBytes(field_number, wire_type)
228      tag_len = len(tag_bytes)
229      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
230        value = field_dict.get(key)
231        if value is None:
232          value = field_dict.setdefault(key, new_default(message))
233        while 1:
234          (element, new_pos) = decode_value(buffer, pos)
235          value.append(element)
236          # Predict that the next tag is another copy of the same repeated
237          # field.
238          pos = new_pos + tag_len
239          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
240            # Prediction failed.  Return.
241            if new_pos > end:
242              raise _DecodeError('Truncated message.')
243            return new_pos
244      return DecodeRepeatedField
245    else:
246      def DecodeField(buffer, pos, end, message, field_dict):
247        (field_dict[key], pos) = decode_value(buffer, pos)
248        if pos > end:
249          del field_dict[key]  # Discard corrupt value.
250          raise _DecodeError('Truncated message.')
251        return pos
252      return DecodeField
253
254  return SpecificDecoder
255
256
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder but passes every value through modify_value first.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value and returns
        a (value, new_pos) pair.
      modify_value:  A one-argument function applied to each decoded value
        before it is stored.  Usually this is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
269
270
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().  Only the first
        unpacked item is used as the field value.
  """

  unpack = struct.unpack
  width = struct.calcsize(format)

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    # We expect someone up-stack to catch struct.error (raised when the slice
    # is shorter than the format requires) and convert it to _DecodeError --
    # this way we don't have to set up exception-handling blocks every time
    # we parse one value.
    stop = pos + width
    chunk = buffer[pos:stop]
    return (unpack(format, chunk)[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
294
295
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.

  Returns:
    A decoder constructor (built with _SimpleDecoder) for fields with wire
    type WIRETYPE_FIXED32.  Infinities and NaN are recognized from their byte
    pattern and returned as _POS_INF, _NEG_INF or _NAN; every other value is
    unpacked with struct.unpack('<f', ...).
  """

  local_unpack = struct.unpack
  # Under Python 2 string literals are already byte strings; otherwise the
  # literal is encoded as latin1 so it can be compared with slices of the
  # buffer.
  b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1')  ##PY25

  def InnerDecode(buffer, pos):
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # The last byte holds the sign bit and the upper seven exponent bits; the
    # top bit of the byte before it is the lowest exponent bit.
    if ((float_bytes[3:4] in b('\x7F\xFF'))  ##PY25
##!PY25    if ((float_bytes[3:4] in b'\x7F\xFF')
        and (float_bytes[2:3] >= b('\x80'))):  ##PY25
##!PY25        and (float_bytes[2:3] >= b'\x80')):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b('\x00\x00\x80'):  ##PY25
##!PY25      if float_bytes[0:3] != b'\x00\x00\x80':
        return (_NAN, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b('\xFF'):  ##PY25
##!PY25      if float_bytes[3:4] == b'\xFF':
        return (_NEG_INF, new_pos)
      return (_POS_INF, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
335
336
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.

  Returns:
    A decoder constructor (built with _SimpleDecoder) for fields with wire
    type WIRETYPE_FIXED64.  NaN is recognized from its byte pattern and
    returned as _NAN; every other value is unpacked with
    struct.unpack('<d', ...).
  """

  local_unpack = struct.unpack
  # Under Python 2 string literals are already byte strings; otherwise the
  # literal is encoded as latin1 so it can be compared with slices of the
  # buffer.
  b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1')  ##PY25

  def InnerDecode(buffer, pos):
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
##!PY25    if ((double_bytes[7:8] in b'\x7F\xFF')
##!PY25        and (double_bytes[6:7] >= b'\xF0')
##!PY25        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
    if ((double_bytes[7:8] in b('\x7F\xFF'))  ##PY25
        and (double_bytes[6:7] >= b('\xF0'))  ##PY25
        and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))):  ##PY25
      return (_NAN, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
369
370
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  Values are read as signed 32-bit varints.  A number that appears in
  key.enum_type.values_by_number is stored in the field.  Any other number is
  kept, in its raw encoded form, in message._unknown_fields as a
  (tag_bytes, value_bytes) pair, where tag_bytes is this field's tag with the
  varint wire type.

  Args:
    field_number:  The field number of the field we want to decode.
    is_repeated:   Is the field a repeated field? (bool)
    is_packed:     Is the field a packed field? (bool)
    key:           The key to use when looking up the field within field_dict.
                   Its enum_type attribute supplies the known values.
    new_default:   A function which takes a message object and returns a new
                   container for the field; used for repeated and packed
                   fields when no container exists yet.

  Returns:
    A decoder function Decode(buffer, pos, end, message, field_dict) that
    returns the new buffer position and raises _DecodeError on truncation.
  """
  enum_type = key.enum_type
  # Unrecognized values are always recorded under this field's varint tag,
  # even when they arrived packed.  The tag is fixed for the field, so it is
  # encoded once here instead of once per unrecognized value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # A packed field is a length prefix followed by the bare values.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos]))
      if pos > endpoint:
        # The last element ran past the declared length; undo whichever
        # record it produced.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
          (tag_bytes, buffer[value_start_pos:pos]))
      return pos
    return DecodeField
445
446
447# --------------------------------------------------------------------
448
449
# Varint-encoded integers.  The signed types sign-extend the raw varint, the
# unsigned types only mask it.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)

# ZigZag-encoded integers:  read the unsigned varint, then undo the ZigZag
# mapping.
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types.  Note that Python conveniently guarantees that when using
# the '<' prefix on formats, they will also have the same size across all
# platforms (as opposed to without the prefix, where their sizes depend on the
# C compiler's basic type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are varints converted with bool().
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)
477
478
479def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
480  """Returns a decoder for a string field."""
481
482  local_DecodeVarint = _DecodeVarint
483  local_unicode = unicode
484
485  def _ConvertToUnicode(byte_str):
486    try:
487      return local_unicode(byte_str, 'utf-8')
488    except UnicodeDecodeError, e:
489      # add more information to the error message and re-raise it.
490      e.reason = '%s in field: %s' % (e, key.full_name)
491      raise
492
493  assert not is_packed
494  if is_repeated:
495    tag_bytes = encoder.TagBytes(field_number,
496                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
497    tag_len = len(tag_bytes)
498    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
499      value = field_dict.get(key)
500      if value is None:
501        value = field_dict.setdefault(key, new_default(message))
502      while 1:
503        (size, pos) = local_DecodeVarint(buffer, pos)
504        new_pos = pos + size
505        if new_pos > end:
506          raise _DecodeError('Truncated string.')
507        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
508        # Predict that the next tag is another copy of the same repeated field.
509        pos = new_pos + tag_len
510        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
511          # Prediction failed.  Return.
512          return new_pos
513    return DecodeRepeatedField
514  else:
515    def DecodeField(buffer, pos, end, message, field_dict):
516      (size, pos) = local_DecodeVarint(buffer, pos)
517      new_pos = pos + size
518      if new_pos > end:
519        raise _DecodeError('Truncated string.')
520      field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
521      return new_pos
522    return DecodeField
523
524
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value is a varint length followed by that many bytes, which are stored
  as a slice of the buffer without further conversion.  Bytes fields cannot be
  packed.

  Raises (from the returned decoder):
    _DecodeError:  The declared length runs past the end of the message.
  """

  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[pos:value_end]
      return value_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while 1:
      (size, pos) = read_length(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:value_end])
      # Predict that the next tag is another copy of the same repeated field.
      pos = value_end + tag_len
      if value_end == end or buffer[value_end:pos] != tag_bytes:
        # Prediction failed.  Return.
        return value_end
  return DecodeRepeatedField
560
561
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group has no length prefix:  the sub-message is parsed until its parser
  returns, and the bytes that follow must be this field's end-group tag.
  Group fields cannot be packed.

  Args:
    field_number:  The field number of the field we want to decode.
    is_repeated:   Is the field a repeated field? (bool)
    is_packed:     Must be false.
    key:           The key to use when looking up the field within field_dict.
    new_default:   A function which takes a message object and returns a new
                   default value (sub-message or container) for this field.

  Raises (from the returned decoder):
    _DecodeError:  The end-group tag is missing or lies past |end|.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the container once; every group decoded by the loop
      # below is added to this same container.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
607
608
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value is a varint length followed by that many bytes holding the
  encoded sub-message.  Message fields cannot be packed.

  Args:
    field_number:  The field number of the field we want to decode.
    is_repeated:   Is the field a repeated field? (bool)
    is_packed:     Must be false.
    key:           The key to use when looking up the field within field_dict.
    new_default:   A function which takes a message object and returns a new
                   default value (sub-message or container) for this field.

  Raises (from the returned decoder):
    _DecodeError:  The declared length runs past |end|, or the sub-message
      parser stopped before consuming the declared length.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the container once; every sub-message decoded by the
      # loop below is added to this same container.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField
660
661
662# --------------------------------------------------------------------
663
# Start-group tag of field 1, i.e. the tag that opens a MessageSet item.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  The returned decoder is called with |pos| pointing just past the item's
  start-group tag.  If type_id names a known extension, the payload is parsed
  into that extension's message (created on first use).  Otherwise the raw
  bytes of the item, from |pos| through its end-group tag, are appended to
  message._unknown_fields under MESSAGE_SET_ITEM_TAG.

  Raises (from the returned decoder):
    _DecodeError:  The item is truncated, lacks a type_id or a message, a
      nested group is missing its end tag, or the payload contains an
      unexpected end-group tag.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the three values below.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload lies; it is parsed after the loop,
        # once the type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    # Bounds are checked once here rather than after every field.
    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
                                      buffer[message_set_item_start:pos]))

    return pos

  return DecodeItem
737
738# --------------------------------------------------------------------
739# Optimization is not as heavy here because calls to SkipField() are rare,
740# except for handling end-group tags.
741
742def _SkipVarint(buffer, pos, end):
743  """Skip a varint value.  Returns the new position."""
744  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
745  # With this code, ord(b'') raises TypeError.  Both are handled in
746  # python_message.py to generate a 'Truncated message' error.
747  while ord(buffer[pos:pos+1]) & 0x80:
748    pos += 1
749  pos += 1
750  if pos > end:
751    raise _DecodeError('Truncated message.')
752  return pos
753
754def _SkipFixed64(buffer, pos, end):
755  """Skip a fixed64 value.  Returns the new position."""
756
757  pos += 8
758  if pos > end:
759    raise _DecodeError('Truncated message.')
760  return pos
761
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  The value is a varint length followed by that many bytes.  Raises
  _DecodeError if the declared length runs past |end|.
  """

  (size, payload_start) = _DecodeVarint(buffer, pos)
  after_value = payload_start + size
  if after_value <= end:
    return after_value
  raise _DecodeError('Truncated message.')
770
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are skipped one at a time until SkipField reports an end-group tag
  by returning -1; the position just past that tag is returned.
  """

  while 1:
    (tag_bytes, after_tag) = ReadTag(buffer, pos)
    pos = SkipField(buffer, after_tag, end, tag_bytes)
    if pos == -1:
      return after_tag
780
def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.

  The arguments are ignored; they are accepted only so that this function has
  the same signature as the other skip functions.
  """

  return -1
785
786def _SkipFixed32(buffer, pos, end):
787  """Skip a fixed32 value.  Returns the new position."""
788
789  pos += 4
790  if pos > end:
791    raise _DecodeError('Truncated message.')
792  return pos
793
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception.

  The arguments are ignored; _DecodeError is raised unconditionally.
  """

  raise _DecodeError('Tag had invalid wire type.')
798
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Skip functions indexed by wire type number.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField

# The module's single shared field skipper.
SkipField = _FieldSkipper()
832