1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
"""Tests for protorpc.generate_proto."""
19
20
21import os
22import shutil
23import cStringIO
24import sys
25import tempfile
26import unittest
27
28from protorpc import descriptor
29from protorpc import generate_proto
30from protorpc import test_util
31from protorpc import util
32
33
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
  """Applies test_util.ModuleInterfaceTest to the generate_proto module.

  The module under test is selected through the MODULE class attribute;
  presumably the mixin checks the module's public names — see test_util.
  """

  MODULE = generate_proto
38
39
class FormatProtoFileTest(test_util.TestCase):
  """Tests for generate_proto.format_proto_file.

  Each test populates a descriptor.FileDescriptor, formats it into an
  in-memory buffer and compares the buffer contents with the expected
  .proto source text.
  """

  def setUp(self):
    self.file_descriptor = descriptor.FileDescriptor()
    self.output = cStringIO.StringIO()

  @property
  def result(self):
    """All text written to self.output so far."""
    return self.output.getvalue()

  def MakeMessage(self, name='MyMessage', fields=None):
    """Append a new message descriptor to self.file_descriptor.

    Args:
      name: Name of the message type to add.
      fields: Sequence of descriptor.FieldDescriptor instances for the
        message.  None (the default) means the message has no fields.
    """
    # None sentinel instead of a mutable [] default shared across calls.
    if fields is None:
      fields = []

    message = descriptor.MessageDescriptor()
    message.name = name
    message.fields = fields

    # Messages live in message_types (assigned below), so read the existing
    # list from there; copying it lets repeated calls accumulate messages
    # instead of replacing the previously added ones.
    messages_list = list(self.file_descriptor.message_types or [])
    messages_list.append(message)
    self.file_descriptor.message_types = messages_list

  def testBlankPackage(self):
    """No package and no messages produce no output at all."""
    self.file_descriptor.package = None
    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('', self.result)

  def testEmptyPackage(self):
    """A package without messages emits only the package statement."""
    self.file_descriptor.package = 'my_package'
    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('package my_package;\n', self.result)

  def testSingleField(self):
    """An optional int64 field with no default."""
    field = descriptor.FieldDescriptor()
    field.name = 'integer_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.INT64

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     '  optional int64 integer_field = 1;\n'
                     '}\n',
                     self.result)

  def testSingleFieldWithDefault(self):
    """A scalar default is written as a [default=...] option."""
    field = descriptor.FieldDescriptor()
    field.name = 'integer_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.INT64
    field.default_value = '10'

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     '  optional int64 integer_field = 1 [default=10];\n'
                     '}\n',
                     self.result)

  def testRepeatedFieldWithDefault(self):
    """A default on a repeated field is not emitted."""
    field = descriptor.FieldDescriptor()
    field.name = 'integer_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.REPEATED
    field.variant = descriptor.FieldDescriptor.Variant.INT64
    field.default_value = '[10, 20]'

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     '  repeated int64 integer_field = 1;\n'
                     '}\n',
                     self.result)

  def testSingleFieldWithDefaultString(self):
    """A string default is emitted in single quotes."""
    field = descriptor.FieldDescriptor()
    field.name = 'string_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.STRING
    field.default_value = 'hello'

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     "  optional string string_field = 1 [default='hello'];\n"
                     '}\n',
                     self.result)

  def testSingleFieldWithDefaultEmptyString(self):
    """An empty string default is still emitted, as ''."""
    field = descriptor.FieldDescriptor()
    field.name = 'string_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.STRING
    field.default_value = ''

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     "  optional string string_field = 1 [default=''];\n"
                     '}\n',
                     self.result)

  def testSingleFieldWithDefaultMessage(self):
    """A message field uses type_name as its type; its default is dropped."""
    field = descriptor.FieldDescriptor()
    field.name = 'message_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.MESSAGE
    field.type_name = 'MyNestedMessage'
    field.default_value = 'not valid'

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     "  optional MyNestedMessage message_field = 1;\n"
                     '}\n',
                     self.result)

  def testSingleFieldWithDefaultEnum(self):
    """An enum field uses type_name as its type; its default is unquoted."""
    field = descriptor.FieldDescriptor()
    field.name = 'enum_field'
    field.number = 1
    field.label = descriptor.FieldDescriptor.Label.OPTIONAL
    field.variant = descriptor.FieldDescriptor.Variant.ENUM
    field.type_name = 'my_package.MyEnum'
    field.default_value = '17'

    self.MakeMessage(fields=[field])

    generate_proto.format_proto_file(self.file_descriptor, self.output)
    self.assertEqual('\n\n'
                     'message MyMessage {\n'
                     "  optional my_package.MyEnum enum_field = 1 "
                     "[default=17];\n"
                     '}\n',
                     self.result)
189
190
def main():
  """Hand control to unittest's command-line runner for this module."""
  unittest.main()


if __name__ == '__main__':
  # Only run the suite when this file is executed directly, not on import.
  main()
197
198