1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# http://code.google.com/p/protobuf/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This module should only be used when the
type of protocol buffers used in an application or library cannot be
predetermined.

Below is a straightforward example of how to use this class:

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'
55
56from google.protobuf import descriptor_pb2
57from google.protobuf import descriptor
58from google.protobuf import descriptor_database
59
60
class DescriptorPool(object):
  """A collection of descriptors built dynamically from descriptor protos.

  FileDescriptorProtos registered with Add(), or served on demand by an
  optional secondary descriptor database, are converted into FileDescriptor,
  Descriptor and EnumDescriptor objects the first time the file or one of its
  symbols is looked up. Converted descriptors are cached in the pool.
  """

  def __init__(self, descriptor_db=None):
    """Initializes a pool of protocol buffer descriptors.

    The descriptor_db argument to the constructor is provided to allow
    specialized file descriptor proto lookup code to be triggered on demand. An
    example would be an implementation which reads and compiles a file
    specified in a call to FindFileByName() and does not require a call to
    Add() at all. Results from this database are cached internally here as
    well.

    Args:
      descriptor_db: A secondary source of file descriptor protos, consulted
        only when a lookup misses the pool's own database. May be None.
    """

    self._internal_db = descriptor_database.DescriptorDatabase()
    self._descriptor_db = descriptor_db
    # Full message name (without a leading dot) -> Descriptor.
    self._descriptors = {}
    # Full enum name (without a leading dot) -> EnumDescriptor.
    self._enum_descriptors = {}
    # File name -> FileDescriptor.
    self._file_descriptors = {}

  def Add(self, file_desc_proto):
    """Adds the FileDescriptorProto and its types to this pool.

    The proto is only stored here; it is converted into descriptors the first
    time the file or one of its symbols is looked up.

    Args:
      file_desc_proto: The FileDescriptorProto to add.
    """

    self._internal_db.Add(file_desc_proto)

  def FindFileByName(self, file_name):
    """Gets a FileDescriptor by file name.

    The pool's own database is searched first; on a miss, the secondary
    descriptor database (if one was given) is consulted.

    Args:
      file_name: The path to the file to get a descriptor for.

    Returns:
      A FileDescriptor for the named file.

    Raises:
      KeyError: if the file can not be found in the pool.
    """

    try:
      file_proto = self._internal_db.FindFileByName(file_name)
    except KeyError:
      if self._descriptor_db is None:
        raise
      file_proto = self._descriptor_db.FindFileByName(file_name)
    if not file_proto:
      raise KeyError('Cannot find a file named %s' % file_name)
    return self._ConvertFileProtoToFileDescriptor(file_proto)

  def FindFileContainingSymbol(self, symbol):
    """Gets the FileDescriptor for the file containing the specified symbol.

    The pool's own database is searched first; on a miss, the secondary
    descriptor database (if one was given) is consulted.

    Args:
      symbol: The name of the symbol to search for.

    Returns:
      A FileDescriptor that contains the specified symbol.

    Raises:
      KeyError: if the file can not be found in the pool.
    """

    try:
      file_proto = self._internal_db.FindFileContainingSymbol(symbol)
    except KeyError:
      if self._descriptor_db is None:
        raise
      file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
    if not file_proto:
      raise KeyError('Cannot find a file containing %s' % symbol)
    return self._ConvertFileProtoToFileDescriptor(file_proto)

  def FindMessageTypeByName(self, full_name):
    """Loads the named message descriptor from the pool.

    Args:
      full_name: The full name of the message type to load. Leading dots are
        ignored, so '.pkg.Msg' and 'pkg.Msg' are equivalent.

    Returns:
      The Descriptor for the named type.

    Raises:
      KeyError: if no file defining the symbol can be found, or the symbol is
        not a message type.
    """

    full_name = full_name.lstrip('.')  # fix inconsistent qualified name formats
    if full_name not in self._descriptors:
      # Converting the defining file registers all of its types as a side
      # effect, which fills in self._descriptors.
      self.FindFileContainingSymbol(full_name)
    return self._descriptors[full_name]

  def FindEnumTypeByName(self, full_name):
    """Loads the named enum descriptor from the pool.

    Args:
      full_name: The full name of the enum descriptor to load. Leading dots
        are ignored, so '.pkg.Enum' and 'pkg.Enum' are equivalent.

    Returns:
      The EnumDescriptor for the named type.

    Raises:
      KeyError: if no file defining the symbol can be found, or the symbol is
        not an enum type.
    """

    full_name = full_name.lstrip('.')  # fix inconsistent qualified name formats
    if full_name not in self._enum_descriptors:
      # Converting the defining file registers all of its types as a side
      # effect, which fills in self._enum_descriptors.
      self.FindFileContainingSymbol(full_name)
    return self._enum_descriptors[full_name]

  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool, and of adding the
    proto to the pool's internal database.

    Args:
      file_proto: The proto to convert. It is not modified.

    Returns:
      A FileDescriptor matching the passed in proto.
    """

    if file_proto.name not in self._file_descriptors:
      file_descriptor = descriptor.FileDescriptor(
          name=file_proto.name,
          package=file_proto.package,
          options=file_proto.options,
          serialized_pb=file_proto.SerializeToString())
      scope = self._BuildDependencyScope(file_proto)

      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope)
        file_descriptor.message_types_by_name[message_desc.name] = message_desc
      for enum_type in file_proto.enum_type:
        self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                    file_descriptor, None, scope)
      # Field types can only be resolved once every type of the file is in
      # scope. _SetFieldTypes recurses into nested types itself, so only the
      # top-level messages are passed here.
      for message_type in file_proto.message_type:
        self._SetFieldTypes(message_type, scope, file_proto.package)

      self.Add(file_proto)
      self._file_descriptors[file_proto.name] = file_descriptor

    return self._file_descriptors[file_proto.name]

  def _BuildDependencyScope(self, file_proto):
    """Collects the message and enum types defined by a file's dependencies.

    Args:
      file_proto: The FileDescriptorProto whose direct and indirect
        dependencies are scanned.

    Returns:
      A dict mapping type names to descriptors. Each type is keyed both by its
      full name with a leading dot and by its name relative to its package.
    """

    scope = {}
    for dep_proto in self._GetDeps(file_proto):
      package = '.' + dep_proto.package
      package_prefix = package + '.'
      message_symbols = list(
          self._ExtractSymbols(dep_proto.message_type, package))
      enum_symbols = list(self._ExtractEnums(dep_proto.enum_type, package))
      for symbols in (message_symbols, enum_symbols):
        for name, desc in symbols:
          scope[name] = desc
        for name, desc in symbols:
          if name.startswith(package_prefix):
            name = name[len(package_prefix):]
          scope[name] = desc
    return scope

  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                                scope=None):
    """Adds the proto to the pool in the specified package.

    Nested message and enum types are converted recursively. Each converted
    message is registered in scope under its short name and under its full
    name with a leading dot, and in the pool under its full name.

    Args:
      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
      package: The package (or enclosing message's full name) the proto
        should be located in.
      file_desc: The file containing this message.
      scope: Dict mapping short and full symbols to message and enum types.

    Returns:
      The added descriptor.
    """

    if package:
      desc_name = '.'.join((package, desc_proto.name))
    else:
      desc_name = desc_proto.name

    file_name = None if file_desc is None else file_desc.name

    if scope is None:
      scope = {}

    nested_types = [
        self._ConvertMessageDescriptor(nested_proto, desc_name, file_desc,
                                       scope)
        for nested_proto in desc_proto.nested_type]
    enum_types = [
        self._ConvertEnumDescriptor(enum_proto, desc_name, file_desc, None,
                                    scope)
        for enum_proto in desc_proto.enum_type]
    fields = [self._MakeFieldDescriptor(field, desc_name, index)
              for index, field in enumerate(desc_proto.field)]
    # Extensions are indexed within desc_proto.extension and must be flagged
    # as extensions (the flag used to be passed in the index position).
    extensions = [
        self._MakeFieldDescriptor(extension, desc_name, index,
                                  is_extension=True)
        for index, extension in enumerate(desc_proto.extension)]
    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    desc = descriptor.Descriptor(
        name=desc_proto.name,
        full_name=desc_name,
        filename=file_name,
        containing_type=None,
        fields=fields,
        nested_types=nested_types,
        enum_types=enum_types,
        extensions=extensions,
        options=desc_proto.options,
        is_extendable=bool(extension_ranges),
        extension_ranges=extension_ranges,
        file=file_desc,
        serialized_start=None,
        serialized_end=None)
    # The children were built before their parent existed; link them now.
    for nested_desc in desc.nested_types:
      nested_desc.containing_type = desc
    for enum_desc in desc.enum_types:
      enum_desc.containing_type = desc
    scope[desc_proto.name] = desc
    scope['.' + desc_name] = desc
    self._descriptors[desc_name] = desc
    return desc

  def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
                             containing_type=None, scope=None):
    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.

    The enum is registered in scope under its short name and under its full
    name with a leading dot, and in the pool under its full name.

    Args:
      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
      package: Optional package name for the new message EnumDescriptor.
      file_desc: The file containing the enum descriptor.
      containing_type: The type containing this enum.
      scope: Scope containing available types. A fresh dict is used if None.

    Returns:
      The added descriptor.
    """

    if package:
      enum_name = '.'.join((package, enum_proto.name))
    else:
      enum_name = enum_proto.name

    file_name = None if file_desc is None else file_desc.name

    if scope is None:
      scope = {}

    values = [self._MakeEnumValueDescriptor(value, index)
              for index, value in enumerate(enum_proto.value)]
    desc = descriptor.EnumDescriptor(name=enum_proto.name,
                                     full_name=enum_name,
                                     filename=file_name,
                                     file=file_desc,
                                     values=values,
                                     containing_type=containing_type,
                                     options=enum_proto.options)
    scope[enum_proto.name] = desc
    scope['.%s' % enum_name] = desc
    self._enum_descriptors[enum_name] = desc
    return desc

  def _MakeFieldDescriptor(self, field_proto, message_name, index,
                           is_extension=False):
    """Creates a field descriptor from a FieldDescriptorProto.

    The type-dependent attributes (cpp_type, message_type, enum_type and the
    default value) are left unset here; _SetFieldTypes fills them in once
    every type the field may refer to has been converted.

    Args:
      field_proto: The proto describing the field.
      message_name: The name of the containing message.
      index: Index of the field.
      is_extension: Indication that this field is for an extension.

    Returns:
      An initialized FieldDescriptor object.
    """

    if message_name:
      full_name = '.'.join((message_name, field_proto.name))
    else:
      full_name = field_proto.name

    return descriptor.FieldDescriptor(
        name=field_proto.name,
        full_name=full_name,
        index=index,
        number=field_proto.number,
        type=field_proto.type,
        cpp_type=None,
        message_type=None,
        enum_type=None,
        containing_type=None,
        label=field_proto.label,
        has_default_value=False,
        default_value=None,
        is_extension=is_extension,
        extension_scope=None,
        options=field_proto.options)

  def _SetFieldTypes(self, desc_proto, scope, package=None):
    """Resolves field types of a message and, recursively, its nested types.

    Args:
      desc_proto: The DescriptorProto whose converted descriptor is updated.
      scope: Enclosing scope of available types, keyed by short names and by
        full names with a leading dot.
      package: The package (or enclosing message's full name) desc_proto was
        converted under. When given, the message descriptor is found by its
        unambiguous full name. When None, it falls back to a lookup by short
        name, which picks the wrong type if a later-converted type shares
        that short name.
    """

    if package is None:
      message_desc = scope[desc_proto.name]
    elif package:
      message_desc = scope['.%s.%s' % (package, desc_proto.name)]
    else:
      message_desc = scope['.' + desc_proto.name]

    for field_proto, field_desc in zip(desc_proto.field, message_desc.fields):
      self._SetFieldType(field_proto, field_desc, scope)

    for nested_type in desc_proto.nested_type:
      self._SetFieldTypes(nested_type, scope, message_desc.full_name)

  def _SetFieldType(self, field_proto, field_desc, scope):
    """Sets one field's type, cpp_type, message_type, enum_type and default.

    Args:
      field_proto: The FieldDescriptorProto describing the field. It is not
        modified.
      field_desc: The FieldDescriptor to update.
      scope: Enclosing scope of available types.

    Raises:
      KeyError: if the field refers to a type that is not in scope, or its
        enum default names an unknown value.
    """

    if field_proto.type_name:
      type_name = field_proto.type_name
      if type_name not in scope:
        type_name = '.' + type_name
      type_desc = scope[type_name]
    else:
      type_desc = None

    if field_proto.HasField('type'):
      field_type = field_proto.type
    elif isinstance(type_desc, descriptor.Descriptor):
      # 'type' is optional in the proto; infer it from what type_name names.
      field_type = descriptor.FieldDescriptor.TYPE_MESSAGE
    else:
      field_type = descriptor.FieldDescriptor.TYPE_ENUM

    field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
        field_type)

    if field_type in (descriptor.FieldDescriptor.TYPE_MESSAGE,
                      descriptor.FieldDescriptor.TYPE_GROUP):
      field_desc.message_type = type_desc
    elif field_type == descriptor.FieldDescriptor.TYPE_ENUM:
      field_desc.enum_type = type_desc

    if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      field_desc.has_default_value = False
      field_desc.default_value = []
    elif field_proto.HasField('default_value'):
      field_desc.has_default_value = True
      field_desc.default_value = self._ParseDefaultValue(
          field_type, field_proto.default_value, field_desc.enum_type)
    else:
      field_desc.has_default_value = False
      field_desc.default_value = None

    field_desc.type = field_type

  def _ParseDefaultValue(self, field_type, default_text, enum_desc):
    """Converts the textual default of a singular field to a Python value.

    Args:
      field_type: The field's FieldDescriptor.TYPE_* value.
      default_text: The FieldDescriptorProto.default_value string.
      enum_desc: The field's EnumDescriptor; only used for enum fields.

    Returns:
      A float, string, bool, enum value number or int, depending on
      field_type.

    Raises:
      KeyError: if an enum default names a value the enum does not define.
      ValueError: if a numeric default cannot be parsed.
    """

    field_types = descriptor.FieldDescriptor
    if field_type in (field_types.TYPE_DOUBLE, field_types.TYPE_FLOAT):
      return float(default_text)
    if field_type == field_types.TYPE_STRING:
      return default_text
    if field_type == field_types.TYPE_BYTES:
      # Previously fell through to int(), which fails on ordinary text.
      # NOTE(review): descriptor.proto documents bytes defaults as C-escaped;
      # they are returned undecoded here - confirm whether unescaping is needed.
      return default_text
    if field_type == field_types.TYPE_BOOL:
      return default_text.lower() == 'true'
    if field_type == field_types.TYPE_ENUM:
      # Use the value's number, not its position in the enum declaration.
      return enum_desc.values_by_name[default_text].number
    return int(default_text)

  def _MakeEnumValueDescriptor(self, value_proto, index):
    """Creates an enum value descriptor object from an enum value proto.

    Args:
      value_proto: The proto describing the enum value.
      index: The index of the enum value.

    Returns:
      An initialized EnumValueDescriptor object.
    """

    return descriptor.EnumValueDescriptor(
        name=value_proto.name,
        index=index,
        number=value_proto.number,
        options=value_proto.options,
        type=None)

  def _ExtractSymbols(self, desc_protos, package):
    """Pulls out all the symbols from descriptor protos.

    Messages are yielded depth-first: each message, then its nested messages,
    then its enums.

    Args:
      desc_protos: The protos to extract symbols from.
      package: The package containing the descriptor type.

    Yields:
      A two element tuple of the type name and descriptor object.
    """

    for desc_proto in desc_protos:
      if package:
        message_name = '.'.join((package, desc_proto.name))
      else:
        message_name = desc_proto.name
      message_desc = self.FindMessageTypeByName(message_name)
      yield (message_name, message_desc)
      for symbol in self._ExtractSymbols(desc_proto.nested_type, message_name):
        yield symbol
      for symbol in self._ExtractEnums(desc_proto.enum_type, message_name):
        yield symbol

  def _ExtractEnums(self, enum_protos, package):
    """Pulls out all the symbols from enum protos.

    Args:
      enum_protos: The protos to extract symbols from.
      package: The package containing the enum type.

    Yields:
      A two element tuple of the type name and enum descriptor object.
    """

    for enum_proto in enum_protos:
      if package:
        enum_name = '.'.join((package, enum_proto.name))
      else:
        enum_name = enum_proto.name
      enum_desc = self.FindEnumTypeByName(enum_name)
      yield (enum_name, enum_desc)

  def _ExtractMessages(self, desc_protos):
    """Pulls out all the message protos from descriptors, depth-first.

    Args:
      desc_protos: The protos to extract symbols from.

    Yields:
      Descriptor protos, each followed by its nested descriptor protos.
    """

    for desc_proto in desc_protos:
      yield desc_proto
      for message in self._ExtractMessages(desc_proto.nested_type):
        yield message

  def _GetDeps(self, file_proto):
    """Recursively finds dependencies for file protos.

    A file reachable through several dependency paths is yielded once per
    path.

    Args:
      file_proto: The proto to get dependencies from.

    Yields:
      Each direct and indirect dependency, as a FileDescriptorProto parsed
      from the dependency's serialized FileDescriptor.
    """

    for dependency in file_proto.dependency:
      dep_desc = self.FindFileByName(dependency)
      dep_proto = descriptor_pb2.FileDescriptorProto.FromString(
          dep_desc.serialized_pb)
      yield dep_proto
      for parent_dep in self._GetDeps(dep_proto):
        yield parent_dep
528