1// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc.  All rights reserved.
3// http://code.google.com/p/protobuf/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are
7// met:
8//
9//     * Redistributions of source code must retain the above copyright
10// notice, this list of conditions and the following disclaimer.
11//     * Redistributions in binary form must reproduce the above
12// copyright notice, this list of conditions and the following disclaimer
13// in the documentation and/or other materials provided with the
14// distribution.
15//     * Neither the name of Google Inc. nor the names of its
16// contributors may be used to endorse or promote products derived from
17// this software without specific prior written permission.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31// Author: kenton@google.com (Kenton Varda)
32//  Based on original Protocol Buffers design by
33//  Sanjay Ghemawat, Jeff Dean, and others.
34//
35// Contains methods defined in extension_set.h which cannot be part of the
36// lite library because they use descriptors or reflection.
37
38#include <google/protobuf/extension_set.h>
39#include <google/protobuf/descriptor.h>
40#include <google/protobuf/message.h>
41#include <google/protobuf/repeated_field.h>
42#include <google/protobuf/wire_format.h>
43#include <google/protobuf/wire_format_lite_inl.h>
44
45namespace google {
46namespace protobuf {
47namespace internal {
48
49// Implementation of ExtensionFinder which finds extensions in a given
50// DescriptorPool, using the given MessageFactory to construct sub-objects.
51// This class is implemented in extension_set_heavy.cc.
52class DescriptorPoolExtensionFinder : public ExtensionFinder {
53 public:
54  DescriptorPoolExtensionFinder(const DescriptorPool* pool,
55                                MessageFactory* factory,
56                                const Descriptor* containing_type)
57      : pool_(pool), factory_(factory), containing_type_(containing_type) {}
58  virtual ~DescriptorPoolExtensionFinder() {}
59
60  virtual bool Find(int number, ExtensionInfo* output);
61
62 private:
63  const DescriptorPool* pool_;
64  MessageFactory* factory_;
65  const Descriptor* containing_type_;
66};
67
68void ExtensionSet::AppendToList(const Descriptor* containing_type,
69                                const DescriptorPool* pool,
70                                vector<const FieldDescriptor*>* output) const {
71  for (map<int, Extension>::const_iterator iter = extensions_.begin();
72       iter != extensions_.end(); ++iter) {
73    bool has = false;
74    if (iter->second.is_repeated) {
75      has = iter->second.GetSize() > 0;
76    } else {
77      has = !iter->second.is_cleared;
78    }
79
80    if (has) {
81      // TODO(kenton): Looking up each field by number is somewhat unfortunate.
82      //   Is there a better way?  The problem is that descriptors are lazily-
83      //   initialized, so they might not even be constructed until
84      //   AppendToList() is called.
85
86      if (iter->second.descriptor == NULL) {
87        output->push_back(pool->FindExtensionByNumber(
88            containing_type, iter->first));
89      } else {
90        output->push_back(iter->second.descriptor);
91      }
92    }
93  }
94}
95
96inline FieldDescriptor::Type real_type(FieldType type) {
97  GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
98  return static_cast<FieldDescriptor::Type>(type);
99}
100
101inline FieldDescriptor::CppType cpp_type(FieldType type) {
102  return FieldDescriptor::TypeToCppType(
103      static_cast<FieldDescriptor::Type>(type));
104}
105
// Debug check that EXTENSION has the expected label (OPTIONAL or REPEATED,
// derived from its is_repeated flag) and the expected C++ type.
// The two checks are wrapped in do { } while (0) so the macro expands to a
// single statement.  This keeps it safe as the body of an unbraced if/else;
// otherwise only the first check would be conditional.
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                        \
  do {                                                                       \
    GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED \
                                    : FieldDescriptor::LABEL_OPTIONAL,       \
              FieldDescriptor::LABEL_##LABEL);                               \
    GOOGLE_DCHECK_EQ(cpp_type((EXTENSION).type),                              \
                     FieldDescriptor::CPPTYPE_##CPPTYPE);                    \
  } while (0)
111
112const MessageLite& ExtensionSet::GetMessage(int number,
113                                            const Descriptor* message_type,
114                                            MessageFactory* factory) const {
115  map<int, Extension>::const_iterator iter = extensions_.find(number);
116  if (iter == extensions_.end() || iter->second.is_cleared) {
117    // Not present.  Return the default value.
118    return *factory->GetPrototype(message_type);
119  } else {
120    GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
121    return *iter->second.message_value;
122  }
123}
124
125MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
126                                          MessageFactory* factory) {
127  Extension* extension;
128  if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
129    extension->type = descriptor->type();
130    GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
131    extension->is_repeated = false;
132    extension->is_packed = false;
133    const MessageLite* prototype =
134        factory->GetPrototype(descriptor->message_type());
135    GOOGLE_CHECK(prototype != NULL);
136    extension->message_value = prototype->New();
137  } else {
138    GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
139  }
140  extension->is_cleared = false;
141  return extension->message_value;
142}
143
144MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
145                                      MessageFactory* factory) {
146  Extension* extension;
147  if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
148    extension->type = descriptor->type();
149    GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
150    extension->is_repeated = true;
151    extension->repeated_message_value =
152      new RepeatedPtrField<MessageLite>();
153  } else {
154    GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
155  }
156
157  // RepeatedPtrField<Message> does not know how to Add() since it cannot
158  // allocate an abstract object, so we have to be tricky.
159  MessageLite* result = extension->repeated_message_value
160      ->AddFromCleared<internal::GenericTypeHandler<MessageLite> >();
161  if (result == NULL) {
162    const MessageLite* prototype;
163    if (extension->repeated_message_value->size() == 0) {
164      prototype = factory->GetPrototype(descriptor->message_type());
165      GOOGLE_CHECK(prototype != NULL);
166    } else {
167      prototype = &extension->repeated_message_value->Get(0);
168    }
169    result = prototype->New();
170    extension->repeated_message_value->AddAllocated(result);
171  }
172  return result;
173}
174
175static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
176  return reinterpret_cast<const EnumDescriptor*>(arg)
177      ->FindValueByNumber(number) != NULL;
178}
179
180bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
181  const FieldDescriptor* extension =
182      pool_->FindExtensionByNumber(containing_type_, number);
183  if (extension == NULL) {
184    return false;
185  } else {
186    output->type = extension->type();
187    output->is_repeated = extension->is_repeated();
188    output->is_packed = extension->options().packed();
189    output->descriptor = extension;
190    if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
191      output->message_prototype =
192          factory_->GetPrototype(extension->message_type());
193      GOOGLE_CHECK(output->message_prototype != NULL)
194          << "Extension factory's GetPrototype() returned NULL for extension: "
195          << extension->full_name();
196    } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
197      output->enum_validity_check.func = ValidateEnumUsingDescriptor;
198      output->enum_validity_check.arg = extension->enum_type();
199    }
200
201    return true;
202  }
203}
204
205bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
206                              const Message* containing_type,
207                              UnknownFieldSet* unknown_fields) {
208  UnknownFieldSetFieldSkipper skipper(unknown_fields);
209  if (input->GetExtensionPool() == NULL) {
210    GeneratedExtensionFinder finder(containing_type);
211    return ParseField(tag, input, &finder, &skipper);
212  } else {
213    DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
214                                         input->GetExtensionFactory(),
215                                         containing_type->GetDescriptor());
216    return ParseField(tag, input, &finder, &skipper);
217  }
218}
219
220bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
221                                   const Message* containing_type,
222                                   UnknownFieldSet* unknown_fields) {
223  UnknownFieldSetFieldSkipper skipper(unknown_fields);
224  if (input->GetExtensionPool() == NULL) {
225    GeneratedExtensionFinder finder(containing_type);
226    return ParseMessageSet(input, &finder, &skipper);
227  } else {
228    DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
229                                         input->GetExtensionFactory(),
230                                         containing_type->GetDescriptor());
231    return ParseMessageSet(input, &finder, &skipper);
232  }
233}
234
235int ExtensionSet::SpaceUsedExcludingSelf() const {
236  int total_size =
237      extensions_.size() * sizeof(map<int, Extension>::value_type);
238  for (map<int, Extension>::const_iterator iter = extensions_.begin(),
239       end = extensions_.end();
240       iter != end;
241       ++iter) {
242    total_size += iter->second.SpaceUsedExcludingSelf();
243  }
244  return total_size;
245}
246
247inline int ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelf(
248    RepeatedPtrFieldBase* field) {
249  return field->SpaceUsedExcludingSelf<GenericTypeHandler<Message> >();
250}
251
252int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
253  int total_size = 0;
254  if (is_repeated) {
255    switch (cpp_type(type)) {
256#define HANDLE_TYPE(UPPERCASE, LOWERCASE)                          \
257      case FieldDescriptor::CPPTYPE_##UPPERCASE:                   \
258        total_size += sizeof(*repeated_##LOWERCASE##_value) +      \
259            repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\
260        break
261
262      HANDLE_TYPE(  INT32,   int32);
263      HANDLE_TYPE(  INT64,   int64);
264      HANDLE_TYPE( UINT32,  uint32);
265      HANDLE_TYPE( UINT64,  uint64);
266      HANDLE_TYPE(  FLOAT,   float);
267      HANDLE_TYPE( DOUBLE,  double);
268      HANDLE_TYPE(   BOOL,    bool);
269      HANDLE_TYPE(   ENUM,    enum);
270      HANDLE_TYPE( STRING,  string);
271#undef HANDLE_TYPE
272
273      case FieldDescriptor::CPPTYPE_MESSAGE:
274        // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
275        // but MessageLite has no SpaceUsed(), so we must directly call
276        // RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type
277        // handler.
278        total_size += sizeof(*repeated_message_value) +
279            RepeatedMessage_SpaceUsedExcludingSelf(repeated_message_value);
280        break;
281    }
282  } else {
283    switch (cpp_type(type)) {
284      case FieldDescriptor::CPPTYPE_STRING:
285        total_size += sizeof(*string_value) +
286                      StringSpaceUsedExcludingSelf(*string_value);
287        break;
288      case FieldDescriptor::CPPTYPE_MESSAGE:
289        total_size += down_cast<Message*>(message_value)->SpaceUsed();
290        break;
291      default:
292        // No extra storage costs for primitive types.
293        break;
294    }
295  }
296  return total_size;
297}
298
299// The Serialize*ToArray methods are only needed in the heavy library, as
300// the lite library only generates SerializeWithCachedSizes.
301uint8* ExtensionSet::SerializeWithCachedSizesToArray(
302    int start_field_number, int end_field_number,
303    uint8* target) const {
304  map<int, Extension>::const_iterator iter;
305  for (iter = extensions_.lower_bound(start_field_number);
306       iter != extensions_.end() && iter->first < end_field_number;
307       ++iter) {
308    target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first,
309                                                               target);
310  }
311  return target;
312}
313
314uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
315    uint8* target) const {
316  map<int, Extension>::const_iterator iter;
317  for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) {
318    target = iter->second.SerializeMessageSetItemWithCachedSizesToArray(
319        iter->first, target);
320  }
321  return target;
322}
323
324uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
325    int number, uint8* target) const {
326  if (is_repeated) {
327    if (is_packed) {
328      if (cached_size == 0) return target;
329
330      target = WireFormatLite::WriteTagToArray(number,
331          WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
332      target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target);
333
334      switch (real_type(type)) {
335#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
336        case FieldDescriptor::TYPE_##UPPERCASE:                             \
337          for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
338            target = WireFormatLite::Write##CAMELCASE##NoTagToArray(        \
339              repeated_##LOWERCASE##_value->Get(i), target);                \
340          }                                                                 \
341          break
342
343        HANDLE_TYPE(   INT32,    Int32,   int32);
344        HANDLE_TYPE(   INT64,    Int64,   int64);
345        HANDLE_TYPE(  UINT32,   UInt32,  uint32);
346        HANDLE_TYPE(  UINT64,   UInt64,  uint64);
347        HANDLE_TYPE(  SINT32,   SInt32,   int32);
348        HANDLE_TYPE(  SINT64,   SInt64,   int64);
349        HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
350        HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
351        HANDLE_TYPE(SFIXED32, SFixed32,   int32);
352        HANDLE_TYPE(SFIXED64, SFixed64,   int64);
353        HANDLE_TYPE(   FLOAT,    Float,   float);
354        HANDLE_TYPE(  DOUBLE,   Double,  double);
355        HANDLE_TYPE(    BOOL,     Bool,    bool);
356        HANDLE_TYPE(    ENUM,     Enum,    enum);
357#undef HANDLE_TYPE
358
359        case WireFormatLite::TYPE_STRING:
360        case WireFormatLite::TYPE_BYTES:
361        case WireFormatLite::TYPE_GROUP:
362        case WireFormatLite::TYPE_MESSAGE:
363          GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
364          break;
365      }
366    } else {
367      switch (real_type(type)) {
368#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
369        case FieldDescriptor::TYPE_##UPPERCASE:                             \
370          for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
371            target = WireFormatLite::Write##CAMELCASE##ToArray(number,      \
372              repeated_##LOWERCASE##_value->Get(i), target);                \
373          }                                                                 \
374          break
375
376        HANDLE_TYPE(   INT32,    Int32,   int32);
377        HANDLE_TYPE(   INT64,    Int64,   int64);
378        HANDLE_TYPE(  UINT32,   UInt32,  uint32);
379        HANDLE_TYPE(  UINT64,   UInt64,  uint64);
380        HANDLE_TYPE(  SINT32,   SInt32,   int32);
381        HANDLE_TYPE(  SINT64,   SInt64,   int64);
382        HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
383        HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
384        HANDLE_TYPE(SFIXED32, SFixed32,   int32);
385        HANDLE_TYPE(SFIXED64, SFixed64,   int64);
386        HANDLE_TYPE(   FLOAT,    Float,   float);
387        HANDLE_TYPE(  DOUBLE,   Double,  double);
388        HANDLE_TYPE(    BOOL,     Bool,    bool);
389        HANDLE_TYPE(  STRING,   String,  string);
390        HANDLE_TYPE(   BYTES,    Bytes,  string);
391        HANDLE_TYPE(    ENUM,     Enum,    enum);
392        HANDLE_TYPE(   GROUP,    Group, message);
393        HANDLE_TYPE( MESSAGE,  Message, message);
394#undef HANDLE_TYPE
395      }
396    }
397  } else if (!is_cleared) {
398    switch (real_type(type)) {
399#define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)                 \
400      case FieldDescriptor::TYPE_##UPPERCASE:                    \
401        target = WireFormatLite::Write##CAMELCASE##ToArray(      \
402            number, VALUE, target); \
403        break
404
405      HANDLE_TYPE(   INT32,    Int32,    int32_value);
406      HANDLE_TYPE(   INT64,    Int64,    int64_value);
407      HANDLE_TYPE(  UINT32,   UInt32,   uint32_value);
408      HANDLE_TYPE(  UINT64,   UInt64,   uint64_value);
409      HANDLE_TYPE(  SINT32,   SInt32,    int32_value);
410      HANDLE_TYPE(  SINT64,   SInt64,    int64_value);
411      HANDLE_TYPE( FIXED32,  Fixed32,   uint32_value);
412      HANDLE_TYPE( FIXED64,  Fixed64,   uint64_value);
413      HANDLE_TYPE(SFIXED32, SFixed32,    int32_value);
414      HANDLE_TYPE(SFIXED64, SFixed64,    int64_value);
415      HANDLE_TYPE(   FLOAT,    Float,    float_value);
416      HANDLE_TYPE(  DOUBLE,   Double,   double_value);
417      HANDLE_TYPE(    BOOL,     Bool,     bool_value);
418      HANDLE_TYPE(  STRING,   String,  *string_value);
419      HANDLE_TYPE(   BYTES,    Bytes,  *string_value);
420      HANDLE_TYPE(    ENUM,     Enum,     enum_value);
421      HANDLE_TYPE(   GROUP,    Group, *message_value);
422      HANDLE_TYPE( MESSAGE,  Message, *message_value);
423#undef HANDLE_TYPE
424    }
425  }
426  return target;
427}
428
429uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
430    int number,
431    uint8* target) const {
432  if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
433    // Not a valid MessageSet extension, but serialize it the normal way.
434    GOOGLE_LOG(WARNING) << "Invalid message set extension.";
435    return SerializeFieldWithCachedSizesToArray(number, target);
436  }
437
438  if (is_cleared) return target;
439
440  // Start group.
441  target = io::CodedOutputStream::WriteTagToArray(
442      WireFormatLite::kMessageSetItemStartTag, target);
443  // Write type ID.
444  target = WireFormatLite::WriteUInt32ToArray(
445      WireFormatLite::kMessageSetTypeIdNumber, number, target);
446  // Write message.
447  target = WireFormatLite::WriteMessageToArray(
448      WireFormatLite::kMessageSetMessageNumber, *message_value, target);
449  // End group.
450  target = io::CodedOutputStream::WriteTagToArray(
451      WireFormatLite::kMessageSetItemEndTag, target);
452  return target;
453}
454
455}  // namespace internal
456}  // namespace protobuf
457}  // namespace google
458