1// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc.  All rights reserved.
3// https://developers.google.com/protocol-buffers/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are
7// met:
8//
9//     * Redistributions of source code must retain the above copyright
10// notice, this list of conditions and the following disclaimer.
11//     * Redistributions in binary form must reproduce the above
12// copyright notice, this list of conditions and the following disclaimer
13// in the documentation and/or other materials provided with the
14// distribution.
15//     * Neither the name of Google Inc. nor the names of its
16// contributors may be used to endorse or promote products derived from
17// this software without specific prior written permission.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31// Author: kenton@google.com (Kenton Varda)
32//  Based on original Protocol Buffers design by
33//  Sanjay Ghemawat, Jeff Dean, and others.
34//
35// Contains methods defined in extension_set.h which cannot be part of the
36// lite library because they use descriptors or reflection.
37
38#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
39#include <google/protobuf/descriptor.h>
40#include <google/protobuf/extension_set.h>
41#include <google/protobuf/message.h>
42#include <google/protobuf/repeated_field.h>
43#include <google/protobuf/wire_format.h>
44#include <google/protobuf/wire_format_lite_inl.h>
45
46namespace google {
47
48namespace protobuf {
49namespace internal {
50
51// A FieldSkipper used to store unknown MessageSet fields into UnknownFieldSet.
52class MessageSetFieldSkipper
53    : public UnknownFieldSetFieldSkipper {
54 public:
55  explicit MessageSetFieldSkipper(UnknownFieldSet* unknown_fields)
56      : UnknownFieldSetFieldSkipper(unknown_fields) {}
57  virtual ~MessageSetFieldSkipper() {}
58
59  virtual bool SkipMessageSetField(io::CodedInputStream* input,
60                                   int field_number);
61};
62bool MessageSetFieldSkipper::SkipMessageSetField(
63    io::CodedInputStream* input, int field_number) {
64  uint32 length;
65  if (!input->ReadVarint32(&length)) return false;
66  if (unknown_fields_ == NULL) {
67    return input->Skip(length);
68  } else {
69    return input->ReadString(
70        unknown_fields_->AddLengthDelimited(field_number), length);
71  }
72}
73
74
75// Implementation of ExtensionFinder which finds extensions in a given
76// DescriptorPool, using the given MessageFactory to construct sub-objects.
77// This class is implemented in extension_set_heavy.cc.
78class DescriptorPoolExtensionFinder : public ExtensionFinder {
79 public:
80  DescriptorPoolExtensionFinder(const DescriptorPool* pool,
81                                MessageFactory* factory,
82                                const Descriptor* containing_type)
83      : pool_(pool), factory_(factory), containing_type_(containing_type) {}
84  virtual ~DescriptorPoolExtensionFinder() {}
85
86  virtual bool Find(int number, ExtensionInfo* output);
87
88 private:
89  const DescriptorPool* pool_;
90  MessageFactory* factory_;
91  const Descriptor* containing_type_;
92};
93
94void ExtensionSet::AppendToList(const Descriptor* containing_type,
95                                const DescriptorPool* pool,
96                                std::vector<const FieldDescriptor*>* output) const {
97  for (map<int, Extension>::const_iterator iter = extensions_.begin();
98       iter != extensions_.end(); ++iter) {
99    bool has = false;
100    if (iter->second.is_repeated) {
101      has = iter->second.GetSize() > 0;
102    } else {
103      has = !iter->second.is_cleared;
104    }
105
106    if (has) {
107      // TODO(kenton): Looking up each field by number is somewhat unfortunate.
108      //   Is there a better way?  The problem is that descriptors are lazily-
109      //   initialized, so they might not even be constructed until
110      //   AppendToList() is called.
111
112      if (iter->second.descriptor == NULL) {
113        output->push_back(pool->FindExtensionByNumber(
114            containing_type, iter->first));
115      } else {
116        output->push_back(iter->second.descriptor);
117      }
118    }
119  }
120}
121
122inline FieldDescriptor::Type real_type(FieldType type) {
123  GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
124  return static_cast<FieldDescriptor::Type>(type);
125}
126
127inline FieldDescriptor::CppType cpp_type(FieldType type) {
128  return FieldDescriptor::TypeToCppType(
129      static_cast<FieldDescriptor::Type>(type));
130}
131
132inline WireFormatLite::FieldType field_type(FieldType type) {
133  GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
134  return static_cast<WireFormatLite::FieldType>(type);
135}
136
// Debug-only check that EXTENSION has the expected label (OPTIONAL or
// REPEATED, derived from is_repeated) and the expected C++ type.  The two
// checks are wrapped in do { } while (0) so the macro expands to exactly one
// statement.  It then stays correct as the body of an unbraced if/else.
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                         \
  do {                                                                      \
    GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED \
                                      : FieldDescriptor::LABEL_OPTIONAL,  \
                     FieldDescriptor::LABEL_##LABEL);                       \
    GOOGLE_DCHECK_EQ(cpp_type((EXTENSION).type),                            \
                     FieldDescriptor::CPPTYPE_##CPPTYPE);                   \
  } while (0)
142
143const MessageLite& ExtensionSet::GetMessage(int number,
144                                            const Descriptor* message_type,
145                                            MessageFactory* factory) const {
146  map<int, Extension>::const_iterator iter = extensions_.find(number);
147  if (iter == extensions_.end() || iter->second.is_cleared) {
148    // Not present.  Return the default value.
149    return *factory->GetPrototype(message_type);
150  } else {
151    GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
152    if (iter->second.is_lazy) {
153      return iter->second.lazymessage_value->GetMessage(
154          *factory->GetPrototype(message_type));
155    } else {
156      return *iter->second.message_value;
157    }
158  }
159}
160
161MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
162                                          MessageFactory* factory) {
163  Extension* extension;
164  if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
165    extension->type = descriptor->type();
166    GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
167    extension->is_repeated = false;
168    extension->is_packed = false;
169    const MessageLite* prototype =
170        factory->GetPrototype(descriptor->message_type());
171    extension->is_lazy = false;
172    extension->message_value = prototype->New();
173    extension->is_cleared = false;
174    return extension->message_value;
175  } else {
176    GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
177    extension->is_cleared = false;
178    if (extension->is_lazy) {
179      return extension->lazymessage_value->MutableMessage(
180          *factory->GetPrototype(descriptor->message_type()));
181    } else {
182      return extension->message_value;
183    }
184  }
185}
186
187MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
188                                          MessageFactory* factory) {
189  map<int, Extension>::iterator iter = extensions_.find(descriptor->number());
190  if (iter == extensions_.end()) {
191    // Not present.  Return NULL.
192    return NULL;
193  } else {
194    GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
195    MessageLite* ret = NULL;
196    if (iter->second.is_lazy) {
197      ret = iter->second.lazymessage_value->ReleaseMessage(
198          *factory->GetPrototype(descriptor->message_type()));
199      delete iter->second.lazymessage_value;
200    } else {
201      ret = iter->second.message_value;
202    }
203    extensions_.erase(descriptor->number());
204    return ret;
205  }
206}
207
208MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
209                                      MessageFactory* factory) {
210  Extension* extension;
211  if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
212    extension->type = descriptor->type();
213    GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
214    extension->is_repeated = true;
215    extension->repeated_message_value =
216      new RepeatedPtrField<MessageLite>();
217  } else {
218    GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
219  }
220
221  // RepeatedPtrField<Message> does not know how to Add() since it cannot
222  // allocate an abstract object, so we have to be tricky.
223  MessageLite* result = extension->repeated_message_value
224      ->AddFromCleared<GenericTypeHandler<MessageLite> >();
225  if (result == NULL) {
226    const MessageLite* prototype;
227    if (extension->repeated_message_value->size() == 0) {
228      prototype = factory->GetPrototype(descriptor->message_type());
229      GOOGLE_CHECK(prototype != NULL);
230    } else {
231      prototype = &extension->repeated_message_value->Get(0);
232    }
233    result = prototype->New();
234    extension->repeated_message_value->AddAllocated(result);
235  }
236  return result;
237}
238
239static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
240  return reinterpret_cast<const EnumDescriptor*>(arg)
241      ->FindValueByNumber(number) != NULL;
242}
243
244bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
245  const FieldDescriptor* extension =
246      pool_->FindExtensionByNumber(containing_type_, number);
247  if (extension == NULL) {
248    return false;
249  } else {
250    output->type = extension->type();
251    output->is_repeated = extension->is_repeated();
252    output->is_packed = extension->options().packed();
253    output->descriptor = extension;
254    if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
255      output->message_prototype =
256          factory_->GetPrototype(extension->message_type());
257      GOOGLE_CHECK(output->message_prototype != NULL)
258          << "Extension factory's GetPrototype() returned NULL for extension: "
259          << extension->full_name();
260    } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
261      output->enum_validity_check.func = ValidateEnumUsingDescriptor;
262      output->enum_validity_check.arg = extension->enum_type();
263    }
264
265    return true;
266  }
267}
268
269bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
270                              const Message* containing_type,
271                              UnknownFieldSet* unknown_fields) {
272  UnknownFieldSetFieldSkipper skipper(unknown_fields);
273  if (input->GetExtensionPool() == NULL) {
274    GeneratedExtensionFinder finder(containing_type);
275    return ParseField(tag, input, &finder, &skipper);
276  } else {
277    DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
278                                         input->GetExtensionFactory(),
279                                         containing_type->GetDescriptor());
280    return ParseField(tag, input, &finder, &skipper);
281  }
282}
283
284bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
285                                   const Message* containing_type,
286                                   UnknownFieldSet* unknown_fields) {
287  MessageSetFieldSkipper skipper(unknown_fields);
288  if (input->GetExtensionPool() == NULL) {
289    GeneratedExtensionFinder finder(containing_type);
290    return ParseMessageSet(input, &finder, &skipper);
291  } else {
292    DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
293                                         input->GetExtensionFactory(),
294                                         containing_type->GetDescriptor());
295    return ParseMessageSet(input, &finder, &skipper);
296  }
297}
298
299int ExtensionSet::SpaceUsedExcludingSelf() const {
300  int total_size =
301      extensions_.size() * sizeof(map<int, Extension>::value_type);
302  for (map<int, Extension>::const_iterator iter = extensions_.begin(),
303       end = extensions_.end();
304       iter != end;
305       ++iter) {
306    total_size += iter->second.SpaceUsedExcludingSelf();
307  }
308  return total_size;
309}
310
311inline int ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelf(
312    RepeatedPtrFieldBase* field) {
313  return field->SpaceUsedExcludingSelf<GenericTypeHandler<Message> >();
314}
315
316int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
317  int total_size = 0;
318  if (is_repeated) {
319    switch (cpp_type(type)) {
320#define HANDLE_TYPE(UPPERCASE, LOWERCASE)                          \
321      case FieldDescriptor::CPPTYPE_##UPPERCASE:                   \
322        total_size += sizeof(*repeated_##LOWERCASE##_value) +      \
323            repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\
324        break
325
326      HANDLE_TYPE(  INT32,   int32);
327      HANDLE_TYPE(  INT64,   int64);
328      HANDLE_TYPE( UINT32,  uint32);
329      HANDLE_TYPE( UINT64,  uint64);
330      HANDLE_TYPE(  FLOAT,   float);
331      HANDLE_TYPE( DOUBLE,  double);
332      HANDLE_TYPE(   BOOL,    bool);
333      HANDLE_TYPE(   ENUM,    enum);
334      HANDLE_TYPE( STRING,  string);
335#undef HANDLE_TYPE
336
337      case FieldDescriptor::CPPTYPE_MESSAGE:
338        // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
339        // but MessageLite has no SpaceUsed(), so we must directly call
340        // RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type
341        // handler.
342        total_size += sizeof(*repeated_message_value) +
343            RepeatedMessage_SpaceUsedExcludingSelf(repeated_message_value);
344        break;
345    }
346  } else {
347    switch (cpp_type(type)) {
348      case FieldDescriptor::CPPTYPE_STRING:
349        total_size += sizeof(*string_value) +
350                      StringSpaceUsedExcludingSelf(*string_value);
351        break;
352      case FieldDescriptor::CPPTYPE_MESSAGE:
353        if (is_lazy) {
354          total_size += lazymessage_value->SpaceUsed();
355        } else {
356          total_size += down_cast<Message*>(message_value)->SpaceUsed();
357        }
358        break;
359      default:
360        // No extra storage costs for primitive types.
361        break;
362    }
363  }
364  return total_size;
365}
366
367// The Serialize*ToArray methods are only needed in the heavy library, as
368// the lite library only generates SerializeWithCachedSizes.
369uint8* ExtensionSet::SerializeWithCachedSizesToArray(
370    int start_field_number, int end_field_number,
371    uint8* target) const {
372  map<int, Extension>::const_iterator iter;
373  for (iter = extensions_.lower_bound(start_field_number);
374       iter != extensions_.end() && iter->first < end_field_number;
375       ++iter) {
376    target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first,
377                                                               target);
378  }
379  return target;
380}
381
382uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
383    uint8* target) const {
384  map<int, Extension>::const_iterator iter;
385  for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) {
386    target = iter->second.SerializeMessageSetItemWithCachedSizesToArray(
387        iter->first, target);
388  }
389  return target;
390}
391
// Serializes this extension as field `number` into the flat buffer `target`
// and returns the position just past the last byte written.
//   - Repeated + packed: writes nothing when cached_size == 0.  Otherwise it
//     writes one LENGTH_DELIMITED tag, then cached_size as the length prefix,
//     then every element without a tag.  cached_size is not computed here.
//     NOTE(review): presumably it was filled by a prior size pass; confirm.
//   - Repeated, not packed: writes each element with its own tag.
//   - Singular: writes the value with its tag unless is_cleared is set.
uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
    int number, uint8* target) const {
  if (is_repeated) {
    if (is_packed) {
      if (cached_size == 0) return target;

      // Packed header: one length-delimited tag followed by the payload size.
      target = WireFormatLite::WriteTagToArray(number,
          WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
      target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target);

      // Packed payload: every element written without its own tag.  Only
      // scalar types can be packed; the length-delimited types log FATAL.
      switch (real_type(type)) {
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
        case FieldDescriptor::TYPE_##UPPERCASE:                             \
          for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
            target = WireFormatLite::Write##CAMELCASE##NoTagToArray(        \
              repeated_##LOWERCASE##_value->Get(i), target);                \
          }                                                                 \
          break

        HANDLE_TYPE(   INT32,    Int32,   int32);
        HANDLE_TYPE(   INT64,    Int64,   int64);
        HANDLE_TYPE(  UINT32,   UInt32,  uint32);
        HANDLE_TYPE(  UINT64,   UInt64,  uint64);
        HANDLE_TYPE(  SINT32,   SInt32,   int32);
        HANDLE_TYPE(  SINT64,   SInt64,   int64);
        HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
        HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
        HANDLE_TYPE(SFIXED32, SFixed32,   int32);
        HANDLE_TYPE(SFIXED64, SFixed64,   int64);
        HANDLE_TYPE(   FLOAT,    Float,   float);
        HANDLE_TYPE(  DOUBLE,   Double,  double);
        HANDLE_TYPE(    BOOL,     Bool,    bool);
        HANDLE_TYPE(    ENUM,     Enum,    enum);
#undef HANDLE_TYPE

        case WireFormatLite::TYPE_STRING:
        case WireFormatLite::TYPE_BYTES:
        case WireFormatLite::TYPE_GROUP:
        case WireFormatLite::TYPE_MESSAGE:
          GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
          break;
      }
    } else {
      // Unpacked repeated field: every element is written with its own tag.
      // The SINT*/SFIXED* wire types share the signed storage containers and
      // the FIXED* types the unsigned ones; GROUP and MESSAGE share
      // repeated_message_value.
      switch (real_type(type)) {
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
        case FieldDescriptor::TYPE_##UPPERCASE:                             \
          for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
            target = WireFormatLite::Write##CAMELCASE##ToArray(number,      \
              repeated_##LOWERCASE##_value->Get(i), target);                \
          }                                                                 \
          break

        HANDLE_TYPE(   INT32,    Int32,   int32);
        HANDLE_TYPE(   INT64,    Int64,   int64);
        HANDLE_TYPE(  UINT32,   UInt32,  uint32);
        HANDLE_TYPE(  UINT64,   UInt64,  uint64);
        HANDLE_TYPE(  SINT32,   SInt32,   int32);
        HANDLE_TYPE(  SINT64,   SInt64,   int64);
        HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
        HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
        HANDLE_TYPE(SFIXED32, SFixed32,   int32);
        HANDLE_TYPE(SFIXED64, SFixed64,   int64);
        HANDLE_TYPE(   FLOAT,    Float,   float);
        HANDLE_TYPE(  DOUBLE,   Double,  double);
        HANDLE_TYPE(    BOOL,     Bool,    bool);
        HANDLE_TYPE(  STRING,   String,  string);
        HANDLE_TYPE(   BYTES,    Bytes,  string);
        HANDLE_TYPE(    ENUM,     Enum,    enum);
        HANDLE_TYPE(   GROUP,    Group, message);
        HANDLE_TYPE( MESSAGE,  Message, message);
#undef HANDLE_TYPE
      }
    }
  } else if (!is_cleared) {
    // Singular field: write the stored value with its tag.  MESSAGE is
    // handled separately because the value may still be held lazily.
    switch (real_type(type)) {
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)                 \
      case FieldDescriptor::TYPE_##UPPERCASE:                    \
        target = WireFormatLite::Write##CAMELCASE##ToArray(      \
            number, VALUE, target); \
        break

      HANDLE_TYPE(   INT32,    Int32,    int32_value);
      HANDLE_TYPE(   INT64,    Int64,    int64_value);
      HANDLE_TYPE(  UINT32,   UInt32,   uint32_value);
      HANDLE_TYPE(  UINT64,   UInt64,   uint64_value);
      HANDLE_TYPE(  SINT32,   SInt32,    int32_value);
      HANDLE_TYPE(  SINT64,   SInt64,    int64_value);
      HANDLE_TYPE( FIXED32,  Fixed32,   uint32_value);
      HANDLE_TYPE( FIXED64,  Fixed64,   uint64_value);
      HANDLE_TYPE(SFIXED32, SFixed32,    int32_value);
      HANDLE_TYPE(SFIXED64, SFixed64,    int64_value);
      HANDLE_TYPE(   FLOAT,    Float,    float_value);
      HANDLE_TYPE(  DOUBLE,   Double,   double_value);
      HANDLE_TYPE(    BOOL,     Bool,     bool_value);
      HANDLE_TYPE(  STRING,   String,  *string_value);
      HANDLE_TYPE(   BYTES,    Bytes,  *string_value);
      HANDLE_TYPE(    ENUM,     Enum,     enum_value);
      HANDLE_TYPE(   GROUP,    Group, *message_value);
#undef HANDLE_TYPE
      case FieldDescriptor::TYPE_MESSAGE:
        if (is_lazy) {
          target = lazymessage_value->WriteMessageToArray(number, target);
        } else {
          target = WireFormatLite::WriteMessageToArray(
              number, *message_value, target);
        }
        break;
    }
  }
  return target;
}
503
504uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
505    int number,
506    uint8* target) const {
507  if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
508    // Not a valid MessageSet extension, but serialize it the normal way.
509    GOOGLE_LOG(WARNING) << "Invalid message set extension.";
510    return SerializeFieldWithCachedSizesToArray(number, target);
511  }
512
513  if (is_cleared) return target;
514
515  // Start group.
516  target = io::CodedOutputStream::WriteTagToArray(
517      WireFormatLite::kMessageSetItemStartTag, target);
518  // Write type ID.
519  target = WireFormatLite::WriteUInt32ToArray(
520      WireFormatLite::kMessageSetTypeIdNumber, number, target);
521  // Write message.
522  if (is_lazy) {
523    target = lazymessage_value->WriteMessageToArray(
524        WireFormatLite::kMessageSetMessageNumber, target);
525  } else {
526    target = WireFormatLite::WriteMessageToArray(
527        WireFormatLite::kMessageSetMessageNumber, *message_value, target);
528  }
529  // End group.
530  target = io::CodedOutputStream::WriteTagToArray(
531      WireFormatLite::kMessageSetItemEndTag, target);
532  return target;
533}
534
535
536bool ExtensionSet::ParseFieldMaybeLazily(
537    int wire_type, int field_number, io::CodedInputStream* input,
538    ExtensionFinder* extension_finder,
539    MessageSetFieldSkipper* field_skipper) {
540  return ParseField(WireFormatLite::MakeTag(
541      field_number, static_cast<WireFormatLite::WireType>(wire_type)),
542                    input, extension_finder, field_skipper);
543}
544
545bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
546                                   ExtensionFinder* extension_finder,
547                                   MessageSetFieldSkipper* field_skipper) {
548  while (true) {
549    const uint32 tag = input->ReadTag();
550    switch (tag) {
551      case 0:
552        return true;
553      case WireFormatLite::kMessageSetItemStartTag:
554        if (!ParseMessageSetItem(input, extension_finder, field_skipper)) {
555          return false;
556        }
557        break;
558      default:
559        if (!ParseField(tag, input, extension_finder, field_skipper)) {
560          return false;
561        }
562        break;
563    }
564  }
565}
566
567bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
568                                   const MessageLite* containing_type) {
569  MessageSetFieldSkipper skipper(NULL);
570  GeneratedExtensionFinder finder(containing_type);
571  return ParseMessageSet(input, &finder, &skipper);
572}
573
// Parses the body of one MessageSet item group; the start-group tag has
// already been consumed by the caller.  Returns true on the end-group tag.
// Returns false on end of input (tag 0) before the end-group tag, or on any
// read/parse failure.  The message field may arrive before or after type_id.
// If it arrives first, it is buffered and parsed once type_id is known.
bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
                                       ExtensionFinder* extension_finder,
                                       MessageSetFieldSkipper* field_skipper) {
  // TODO(kenton):  It would be nice to share code between this and
  // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the
  // differences would be hard to factor out.

  // This method parses a group which should contain two fields:
  //   required int32 type_id = 2;
  //   required data message = 3;

  // 0 doubles as "no type_id seen yet", so a type_id of 0 on the wire is
  // indistinguishable from a missing one.
  uint32 last_type_id = 0;

  // If we see message data before the type_id, we'll append it to this so
  // we can parse it later.
  string message_data;

  while (true) {
    const uint32 tag = input->ReadTag();
    // Tag 0 means the input ended before the item's end-group tag.
    if (tag == 0) return false;

    switch (tag) {
      case WireFormatLite::kMessageSetTypeIdTag: {
        uint32 type_id;
        if (!input->ReadVarint32(&type_id)) return false;
        last_type_id = type_id;

        if (!message_data.empty()) {
          // We saw some message data before the type_id.  Have to parse it
          // now.  The buffer holds the length prefix followed by the bytes,
          // i.e. exactly what a length-delimited field parse expects.
          // NOTE(review): if several message fields preceded type_id they are
          // all buffered, but only one ParseFieldMaybeLazily call is made on
          // sub_input; confirm whether trailing chunks are consumed.
          io::CodedInputStream sub_input(
              reinterpret_cast<const uint8*>(message_data.data()),
              message_data.size());
          if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
                                     last_type_id, &sub_input,
                                     extension_finder, field_skipper)) {
            return false;
          }
          message_data.clear();
        }

        break;
      }

      case WireFormatLite::kMessageSetMessageTag: {
        if (last_type_id == 0) {
          // We haven't seen a type_id yet.  Append this data to message_data.
          // It is re-encoded with its varint length prefix.  The coded stream
          // is scoped to this block: its writes are expected to be committed
          // to message_data when it and output_stream are destroyed.
          // NOTE(review): confirm against io::StringOutputStream.
          string temp;
          uint32 length;
          if (!input->ReadVarint32(&length)) return false;
          if (!input->ReadString(&temp, length)) return false;
          io::StringOutputStream output_stream(&message_data);
          io::CodedOutputStream coded_output(&output_stream);
          coded_output.WriteVarint32(length);
          coded_output.WriteString(temp);
        } else {
          // Already saw type_id, so we can parse this directly.
          if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
                                     last_type_id, input,
                                     extension_finder, field_skipper)) {
            return false;
          }
        }

        break;
      }

      case WireFormatLite::kMessageSetItemEndTag: {
        return true;
      }

      default: {
        // Any other field inside the item is handed to the skipper.
        if (!field_skipper->SkipField(input, tag)) return false;
      }
    }
  }
}
651
652void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes(
653    int number,
654    io::CodedOutputStream* output) const {
655  if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
656    // Not a valid MessageSet extension, but serialize it the normal way.
657    SerializeFieldWithCachedSizes(number, output);
658    return;
659  }
660
661  if (is_cleared) return;
662
663  // Start group.
664  output->WriteTag(WireFormatLite::kMessageSetItemStartTag);
665
666  // Write type ID.
667  WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
668                              number,
669                              output);
670  // Write message.
671  if (is_lazy) {
672    lazymessage_value->WriteMessage(
673        WireFormatLite::kMessageSetMessageNumber, output);
674  } else {
675    WireFormatLite::WriteMessageMaybeToArray(
676        WireFormatLite::kMessageSetMessageNumber,
677        *message_value,
678        output);
679  }
680
681  // End group.
682  output->WriteTag(WireFormatLite::kMessageSetItemEndTag);
683}
684
685int ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
686  if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
687    // Not a valid MessageSet extension, but compute the byte size for it the
688    // normal way.
689    return ByteSize(number);
690  }
691
692  if (is_cleared) return 0;
693
694  int our_size = WireFormatLite::kMessageSetItemTagsSize;
695
696  // type_id
697  our_size += io::CodedOutputStream::VarintSize32(number);
698
699  // message
700  int message_size = 0;
701  if (is_lazy) {
702    message_size = lazymessage_value->ByteSize();
703  } else {
704    message_size = message_value->ByteSize();
705  }
706
707  our_size += io::CodedOutputStream::VarintSize32(message_size);
708  our_size += message_size;
709
710  return our_size;
711}
712
713void ExtensionSet::SerializeMessageSetWithCachedSizes(
714    io::CodedOutputStream* output) const {
715  for (map<int, Extension>::const_iterator iter = extensions_.begin();
716       iter != extensions_.end(); ++iter) {
717    iter->second.SerializeMessageSetItemWithCachedSizes(iter->first, output);
718  }
719}
720
721int ExtensionSet::MessageSetByteSize() const {
722  int total_size = 0;
723
724  for (map<int, Extension>::const_iterator iter = extensions_.begin();
725       iter != extensions_.end(); ++iter) {
726    total_size += iter->second.MessageSetItemByteSize(iter->first);
727  }
728
729  return total_size;
730}
731
732}  // namespace internal
733}  // namespace protobuf
734}  // namespace google
735