1// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc.  All rights reserved.
3// https://developers.google.com/protocol-buffers/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are
7// met:
8//
9//     * Redistributions of source code must retain the above copyright
10// notice, this list of conditions and the following disclaimer.
11//     * Redistributions in binary form must reproduce the above
12// copyright notice, this list of conditions and the following disclaimer
13// in the documentation and/or other materials provided with the
14// distribution.
15//     * Neither the name of Google Inc. nor the names of its
16// contributors may be used to endorse or promote products derived from
17// this software without specific prior written permission.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31// Author: haberman@google.com (Josh Haberman)
32
33#include <google/protobuf/pyext/map_container.h>
34
35#include <memory>
36#ifndef _SHARED_PTR_H
37#include <google/protobuf/stubs/shared_ptr.h>
38#endif
39
40#include <google/protobuf/stubs/logging.h>
41#include <google/protobuf/stubs/common.h>
42#include <google/protobuf/stubs/scoped_ptr.h>
43#include <google/protobuf/map_field.h>
44#include <google/protobuf/map.h>
45#include <google/protobuf/message.h>
46#include <google/protobuf/pyext/message.h>
47#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
48
49#if PY_MAJOR_VERSION >= 3
50  #define PyInt_FromLong PyLong_FromLong
51  #define PyInt_FromSize_t PyLong_FromSize_t
52#endif
53
54namespace google {
55namespace protobuf {
56namespace python {
57
58// Functions that need access to map reflection functionality.
59// They need to be contained in this class because it is friended.
60class MapReflectionFriend {
61 public:
62  // Methods that are in common between the map types.
63  static PyObject* Contains(PyObject* _self, PyObject* key);
64  static Py_ssize_t Length(PyObject* _self);
65  static PyObject* GetIterator(PyObject *_self);
66  static PyObject* IterNext(PyObject* _self);
67
68  // Methods that differ between the map types.
69  static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
70  static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
71  static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72  static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73};
74
75struct MapIterator {
76  PyObject_HEAD;
77
78  google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter;
79
80  // A pointer back to the container, so we can notice changes to the version.
81  // We own a ref on this.
82  MapContainer* container;
83
84  // We need to keep a ref on the Message* too, because
85  // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
86  // the ref on container (above) would guarantee outlive semantics.  However in
87  // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
88  // message pointer (and the owner) to a different message, a copy of the
89  // original.  But our iterator still points to the original, which could now
90  // get deleted before us.
91  //
92  // To prevent this, we ensure that the Message will always stay alive as long
93  // as this iterator does.  This is solely for the benefit of the MapIterator
94  // destructor -- we should never actually access the iterator in this state
95  // except to delete it.
96  shared_ptr<Message> owner;
97
98  // The version of the map when we took the iterator to it.
99  //
100  // We store this so that if the map is modified during iteration we can throw
101  // an error.
102  uint64 version;
103
104  // True if the container is empty.  We signal this separately to avoid calling
105  // any of the iteration methods, which are non-const.
106  bool empty;
107};
108
109Message* MapContainer::GetMutableMessage() {
110  cmessage::AssureWritable(parent);
111  return const_cast<Message*>(message);
112}
113
114// Consumes a reference on the Python string object.
115static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
116  char *value;
117  Py_ssize_t value_len;
118
119  if (!py_string) {
120    return false;
121  }
122  if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
123    Py_DECREF(py_string);
124    return false;
125  } else {
126    stl_string->assign(value, value_len);
127    Py_DECREF(py_string);
128    return true;
129  }
130}
131
132static bool PythonToMapKey(PyObject* obj,
133                           const FieldDescriptor* field_descriptor,
134                           MapKey* key) {
135  switch (field_descriptor->cpp_type()) {
136    case FieldDescriptor::CPPTYPE_INT32: {
137      GOOGLE_CHECK_GET_INT32(obj, value, false);
138      key->SetInt32Value(value);
139      break;
140    }
141    case FieldDescriptor::CPPTYPE_INT64: {
142      GOOGLE_CHECK_GET_INT64(obj, value, false);
143      key->SetInt64Value(value);
144      break;
145    }
146    case FieldDescriptor::CPPTYPE_UINT32: {
147      GOOGLE_CHECK_GET_UINT32(obj, value, false);
148      key->SetUInt32Value(value);
149      break;
150    }
151    case FieldDescriptor::CPPTYPE_UINT64: {
152      GOOGLE_CHECK_GET_UINT64(obj, value, false);
153      key->SetUInt64Value(value);
154      break;
155    }
156    case FieldDescriptor::CPPTYPE_BOOL: {
157      GOOGLE_CHECK_GET_BOOL(obj, value, false);
158      key->SetBoolValue(value);
159      break;
160    }
161    case FieldDescriptor::CPPTYPE_STRING: {
162      string str;
163      if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
164        return false;
165      }
166      key->SetStringValue(str);
167      break;
168    }
169    default:
170      PyErr_Format(
171          PyExc_SystemError, "Type %d cannot be a map key",
172          field_descriptor->cpp_type());
173      return false;
174  }
175  return true;
176}
177
178static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
179                                const MapKey& key) {
180  switch (field_descriptor->cpp_type()) {
181    case FieldDescriptor::CPPTYPE_INT32:
182      return PyInt_FromLong(key.GetInt32Value());
183    case FieldDescriptor::CPPTYPE_INT64:
184      return PyLong_FromLongLong(key.GetInt64Value());
185    case FieldDescriptor::CPPTYPE_UINT32:
186      return PyInt_FromSize_t(key.GetUInt32Value());
187    case FieldDescriptor::CPPTYPE_UINT64:
188      return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
189    case FieldDescriptor::CPPTYPE_BOOL:
190      return PyBool_FromLong(key.GetBoolValue());
191    case FieldDescriptor::CPPTYPE_STRING:
192      return ToStringObject(field_descriptor, key.GetStringValue());
193    default:
194      PyErr_Format(
195          PyExc_SystemError, "Couldn't convert type %d to value",
196          field_descriptor->cpp_type());
197      return NULL;
198  }
199}
200
201// This is only used for ScalarMap, so we don't need to handle the
202// CPPTYPE_MESSAGE case.
203PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
204                              MapValueRef* value) {
205  switch (field_descriptor->cpp_type()) {
206    case FieldDescriptor::CPPTYPE_INT32:
207      return PyInt_FromLong(value->GetInt32Value());
208    case FieldDescriptor::CPPTYPE_INT64:
209      return PyLong_FromLongLong(value->GetInt64Value());
210    case FieldDescriptor::CPPTYPE_UINT32:
211      return PyInt_FromSize_t(value->GetUInt32Value());
212    case FieldDescriptor::CPPTYPE_UINT64:
213      return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
214    case FieldDescriptor::CPPTYPE_FLOAT:
215      return PyFloat_FromDouble(value->GetFloatValue());
216    case FieldDescriptor::CPPTYPE_DOUBLE:
217      return PyFloat_FromDouble(value->GetDoubleValue());
218    case FieldDescriptor::CPPTYPE_BOOL:
219      return PyBool_FromLong(value->GetBoolValue());
220    case FieldDescriptor::CPPTYPE_STRING:
221      return ToStringObject(field_descriptor, value->GetStringValue());
222    case FieldDescriptor::CPPTYPE_ENUM:
223      return PyInt_FromLong(value->GetEnumValue());
224    default:
225      PyErr_Format(
226          PyExc_SystemError, "Couldn't convert type %d to value",
227          field_descriptor->cpp_type());
228      return NULL;
229  }
230}
231
232// This is only used for ScalarMap, so we don't need to handle the
233// CPPTYPE_MESSAGE case.
234static bool PythonToMapValueRef(PyObject* obj,
235                                const FieldDescriptor* field_descriptor,
236                                bool allow_unknown_enum_values,
237                                MapValueRef* value_ref) {
238  switch (field_descriptor->cpp_type()) {
239    case FieldDescriptor::CPPTYPE_INT32: {
240      GOOGLE_CHECK_GET_INT32(obj, value, false);
241      value_ref->SetInt32Value(value);
242      return true;
243    }
244    case FieldDescriptor::CPPTYPE_INT64: {
245      GOOGLE_CHECK_GET_INT64(obj, value, false);
246      value_ref->SetInt64Value(value);
247      return true;
248    }
249    case FieldDescriptor::CPPTYPE_UINT32: {
250      GOOGLE_CHECK_GET_UINT32(obj, value, false);
251      value_ref->SetUInt32Value(value);
252      return true;
253    }
254    case FieldDescriptor::CPPTYPE_UINT64: {
255      GOOGLE_CHECK_GET_UINT64(obj, value, false);
256      value_ref->SetUInt64Value(value);
257      return true;
258    }
259    case FieldDescriptor::CPPTYPE_FLOAT: {
260      GOOGLE_CHECK_GET_FLOAT(obj, value, false);
261      value_ref->SetFloatValue(value);
262      return true;
263    }
264    case FieldDescriptor::CPPTYPE_DOUBLE: {
265      GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
266      value_ref->SetDoubleValue(value);
267      return true;
268    }
269    case FieldDescriptor::CPPTYPE_BOOL: {
270      GOOGLE_CHECK_GET_BOOL(obj, value, false);
271      value_ref->SetBoolValue(value);
272      return true;;
273    }
274    case FieldDescriptor::CPPTYPE_STRING: {
275      string str;
276      if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
277        return false;
278      }
279      value_ref->SetStringValue(str);
280      return true;
281    }
282    case FieldDescriptor::CPPTYPE_ENUM: {
283      GOOGLE_CHECK_GET_INT32(obj, value, false);
284      if (allow_unknown_enum_values) {
285        value_ref->SetEnumValue(value);
286        return true;
287      } else {
288        const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
289        const EnumValueDescriptor* enum_value =
290            enum_descriptor->FindValueByNumber(value);
291        if (enum_value != NULL) {
292          value_ref->SetEnumValue(value);
293          return true;
294        } else {
295          PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
296          return false;
297        }
298      }
299      break;
300    }
301    default:
302      PyErr_Format(
303          PyExc_SystemError, "Setting value to a field of unknown type %d",
304          field_descriptor->cpp_type());
305      return false;
306  }
307}
308
309// Map methods common to ScalarMap and MessageMap //////////////////////////////
310
311static MapContainer* GetMap(PyObject* obj) {
312  return reinterpret_cast<MapContainer*>(obj);
313}
314
315Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
316  MapContainer* self = GetMap(_self);
317  const google::protobuf::Message* message = self->message;
318  return message->GetReflection()->MapSize(*message,
319                                           self->parent_field_descriptor);
320}
321
322PyObject* Clear(PyObject* _self) {
323  MapContainer* self = GetMap(_self);
324  Message* message = self->GetMutableMessage();
325  const Reflection* reflection = message->GetReflection();
326
327  reflection->ClearField(message, self->parent_field_descriptor);
328
329  Py_RETURN_NONE;
330}
331
332PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
333  MapContainer* self = GetMap(_self);
334
335  const Message* message = self->message;
336  const Reflection* reflection = message->GetReflection();
337  MapKey map_key;
338
339  if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
340    return NULL;
341  }
342
343  if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
344                                 map_key)) {
345    Py_RETURN_TRUE;
346  } else {
347    Py_RETURN_FALSE;
348  }
349}
350
351// Initializes the underlying Message object of "to" so it becomes a new parent
352// repeated scalar, and copies all the values from "from" to it. A child scalar
353// container can be released by passing it as both from and to (e.g. making it
354// the recipient of the new parent message and copying the values from itself).
355static int InitializeAndCopyToParentContainer(MapContainer* from,
356                                              MapContainer* to) {
357  // For now we require from == to, re-evaluate if we want to support deep copy
358  // as in repeated_scalar_container.cc.
359  GOOGLE_DCHECK(from == to);
360  Message* new_message = from->message->New();
361
362  if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
363    // A somewhat roundabout way of copying just one field from old_message to
364    // new_message.  This is the best we can do with what Reflection gives us.
365    Message* mutable_old = from->GetMutableMessage();
366    vector<const FieldDescriptor*> fields;
367    fields.push_back(from->parent_field_descriptor);
368
369    // Move the map field into the new message.
370    mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);
371
372    // If/when we support from != to, this will be required also to copy the
373    // map field back into the existing message:
374    // mutable_old->MergeFrom(*new_message);
375  }
376
377  // If from == to this could delete old_message.
378  to->owner.reset(new_message);
379
380  to->parent = NULL;
381  to->parent_field_descriptor = from->parent_field_descriptor;
382  to->message = new_message;
383
384  // Invalidate iterators, since they point to the old copy of the field.
385  to->version++;
386
387  return 0;
388}
389
390int MapContainer::Release() {
391  return InitializeAndCopyToParentContainer(this, this);
392}
393
394
395// ScalarMap ///////////////////////////////////////////////////////////////////
396
397PyObject *NewScalarMapContainer(
398    CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
399  if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
400    return NULL;
401  }
402
403#if PY_MAJOR_VERSION >= 3
404  ScopedPyObjectPtr obj(PyType_GenericAlloc(
405        reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
406#else
407  ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
408#endif
409  if (obj.get() == NULL) {
410    return PyErr_Format(PyExc_RuntimeError,
411                        "Could not allocate new container.");
412  }
413
414  MapContainer* self = GetMap(obj.get());
415
416  self->message = parent->message;
417  self->parent = parent;
418  self->parent_field_descriptor = parent_field_descriptor;
419  self->owner = parent->owner;
420  self->version = 0;
421
422  self->key_field_descriptor =
423      parent_field_descriptor->message_type()->FindFieldByName("key");
424  self->value_field_descriptor =
425      parent_field_descriptor->message_type()->FindFieldByName("value");
426
427  if (self->key_field_descriptor == NULL ||
428      self->value_field_descriptor == NULL) {
429    return PyErr_Format(PyExc_KeyError,
430                        "Map entry descriptor did not have key/value fields");
431  }
432
433  return obj.release();
434}
435
436PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
437                                                PyObject* key) {
438  MapContainer* self = GetMap(_self);
439
440  Message* message = self->GetMutableMessage();
441  const Reflection* reflection = message->GetReflection();
442  MapKey map_key;
443  MapValueRef value;
444
445  if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
446    return NULL;
447  }
448
449  if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
450                                         map_key, &value)) {
451    self->version++;
452  }
453
454  return MapValueRefToPython(self->value_field_descriptor, &value);
455}
456
457int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
458                                          PyObject* v) {
459  MapContainer* self = GetMap(_self);
460
461  Message* message = self->GetMutableMessage();
462  const Reflection* reflection = message->GetReflection();
463  MapKey map_key;
464  MapValueRef value;
465
466  if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
467    return -1;
468  }
469
470  self->version++;
471
472  if (v) {
473    // Set item to v.
474    reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
475                                       map_key, &value);
476
477    return PythonToMapValueRef(v, self->value_field_descriptor,
478                               reflection->SupportsUnknownEnumValues(), &value)
479               ? 0
480               : -1;
481  } else {
482    // Delete key from map.
483    if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
484                                   map_key)) {
485      return 0;
486    } else {
487      PyErr_Format(PyExc_KeyError, "Key not present in map");
488      return -1;
489    }
490  }
491}
492
493static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
494  PyObject* key;
495  PyObject* default_value = NULL;
496  if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
497    return NULL;
498  }
499
500  ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
501  if (is_present.get() == NULL) {
502    return NULL;
503  }
504
505  if (PyObject_IsTrue(is_present.get())) {
506    return MapReflectionFriend::ScalarMapGetItem(self, key);
507  } else {
508    if (default_value != NULL) {
509      Py_INCREF(default_value);
510      return default_value;
511    } else {
512      Py_RETURN_NONE;
513    }
514  }
515}
516
517static void ScalarMapDealloc(PyObject* _self) {
518  MapContainer* self = GetMap(_self);
519  self->owner.reset();
520  Py_TYPE(_self)->tp_free(_self);
521}
522
523static PyMethodDef ScalarMapMethods[] = {
524  { "__contains__", MapReflectionFriend::Contains, METH_O,
525    "Tests whether a key is a member of the map." },
526  { "clear", (PyCFunction)Clear, METH_NOARGS,
527    "Removes all elements from the map." },
528  { "get", ScalarMapGet, METH_VARARGS,
529    "Gets the value for the given key if present, or otherwise a default" },
530  /*
531  { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
532    "Makes a deep copy of the class." },
533  { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
534    "Outputs picklable representation of the repeated field." },
535  */
536  {NULL, NULL},
537};
538
539#if PY_MAJOR_VERSION >= 3
540  static PyType_Slot ScalarMapContainer_Type_slots[] = {
541      {Py_tp_dealloc, (void *)ScalarMapDealloc},
542      {Py_mp_length, (void *)MapReflectionFriend::Length},
543      {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
544      {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
545      {Py_tp_methods, (void *)ScalarMapMethods},
546      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
547      {0, 0},
548  };
549
550  PyType_Spec ScalarMapContainer_Type_spec = {
551      FULL_MODULE_NAME ".ScalarMapContainer",
552      sizeof(MapContainer),
553      0,
554      Py_TPFLAGS_DEFAULT,
555      ScalarMapContainer_Type_slots
556  };
557  PyObject *ScalarMapContainer_Type;
558#else
559  static PyMappingMethods ScalarMapMappingMethods = {
560    MapReflectionFriend::Length,             // mp_length
561    MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
562    MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
563  };
564
565  PyTypeObject ScalarMapContainer_Type = {
566    PyVarObject_HEAD_INIT(&PyType_Type, 0)
567    FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
568    sizeof(MapContainer),                //  tp_basicsize
569    0,                                   //  tp_itemsize
570    ScalarMapDealloc,                    //  tp_dealloc
571    0,                                   //  tp_print
572    0,                                   //  tp_getattr
573    0,                                   //  tp_setattr
574    0,                                   //  tp_compare
575    0,                                   //  tp_repr
576    0,                                   //  tp_as_number
577    0,                                   //  tp_as_sequence
578    &ScalarMapMappingMethods,            //  tp_as_mapping
579    0,                                   //  tp_hash
580    0,                                   //  tp_call
581    0,                                   //  tp_str
582    0,                                   //  tp_getattro
583    0,                                   //  tp_setattro
584    0,                                   //  tp_as_buffer
585    Py_TPFLAGS_DEFAULT,                  //  tp_flags
586    "A scalar map container",            //  tp_doc
587    0,                                   //  tp_traverse
588    0,                                   //  tp_clear
589    0,                                   //  tp_richcompare
590    0,                                   //  tp_weaklistoffset
591    MapReflectionFriend::GetIterator,    //  tp_iter
592    0,                                   //  tp_iternext
593    ScalarMapMethods,                    //  tp_methods
594    0,                                   //  tp_members
595    0,                                   //  tp_getset
596    0,                                   //  tp_base
597    0,                                   //  tp_dict
598    0,                                   //  tp_descr_get
599    0,                                   //  tp_descr_set
600    0,                                   //  tp_dictoffset
601    0,                                   //  tp_init
602  };
603#endif
604
605
606// MessageMap //////////////////////////////////////////////////////////////////
607
608static MessageMapContainer* GetMessageMap(PyObject* obj) {
609  return reinterpret_cast<MessageMapContainer*>(obj);
610}
611
612static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
613  // Get or create the CMessage object corresponding to this message.
614  ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
615  PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
616
617  if (ret == NULL) {
618    CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class);
619    ret = reinterpret_cast<PyObject*>(cmsg);
620
621    if (cmsg == NULL) {
622      return NULL;
623    }
624    cmsg->owner = self->owner;
625    cmsg->message = message;
626    cmsg->parent = self->parent;
627
628    if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
629      Py_DECREF(ret);
630      return NULL;
631    }
632  } else {
633    Py_INCREF(ret);
634  }
635
636  return ret;
637}
638
639PyObject* NewMessageMapContainer(
640    CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
641    CMessageClass* message_class) {
642  if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
643    return NULL;
644  }
645
646#if PY_MAJOR_VERSION >= 3
647  PyObject* obj = PyType_GenericAlloc(
648        reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
649#else
650  PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
651#endif
652  if (obj == NULL) {
653    return PyErr_Format(PyExc_RuntimeError,
654                        "Could not allocate new container.");
655  }
656
657  MessageMapContainer* self = GetMessageMap(obj);
658
659  self->message = parent->message;
660  self->parent = parent;
661  self->parent_field_descriptor = parent_field_descriptor;
662  self->owner = parent->owner;
663  self->version = 0;
664
665  self->key_field_descriptor =
666      parent_field_descriptor->message_type()->FindFieldByName("key");
667  self->value_field_descriptor =
668      parent_field_descriptor->message_type()->FindFieldByName("value");
669
670  self->message_dict = PyDict_New();
671  if (self->message_dict == NULL) {
672    return PyErr_Format(PyExc_RuntimeError,
673                        "Could not allocate message dict.");
674  }
675
676  Py_INCREF(message_class);
677  self->message_class = message_class;
678
679  if (self->key_field_descriptor == NULL ||
680      self->value_field_descriptor == NULL) {
681    Py_DECREF(obj);
682    return PyErr_Format(PyExc_KeyError,
683                        "Map entry descriptor did not have key/value fields");
684  }
685
686  return obj;
687}
688
689int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
690                                           PyObject* v) {
691  if (v) {
692    PyErr_Format(PyExc_ValueError,
693                 "Direct assignment of submessage not allowed");
694    return -1;
695  }
696
697  // Now we know that this is a delete, not a set.
698
699  MessageMapContainer* self = GetMessageMap(_self);
700  Message* message = self->GetMutableMessage();
701  const Reflection* reflection = message->GetReflection();
702  MapKey map_key;
703  MapValueRef value;
704
705  self->version++;
706
707  if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
708    return -1;
709  }
710
711  // Delete key from map.
712  if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
713                                 map_key)) {
714    return 0;
715  } else {
716    PyErr_Format(PyExc_KeyError, "Key not present in map");
717    return -1;
718  }
719}
720
721PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
722                                                 PyObject* key) {
723  MessageMapContainer* self = GetMessageMap(_self);
724
725  Message* message = self->GetMutableMessage();
726  const Reflection* reflection = message->GetReflection();
727  MapKey map_key;
728  MapValueRef value;
729
730  if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
731    return NULL;
732  }
733
734  if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
735                                         map_key, &value)) {
736    self->version++;
737  }
738
739  return GetCMessage(self, value.MutableMessageValue());
740}
741
742PyObject* MessageMapGet(PyObject* self, PyObject* args) {
743  PyObject* key;
744  PyObject* default_value = NULL;
745  if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
746    return NULL;
747  }
748
749  ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
750  if (is_present.get() == NULL) {
751    return NULL;
752  }
753
754  if (PyObject_IsTrue(is_present.get())) {
755    return MapReflectionFriend::MessageMapGetItem(self, key);
756  } else {
757    if (default_value != NULL) {
758      Py_INCREF(default_value);
759      return default_value;
760    } else {
761      Py_RETURN_NONE;
762    }
763  }
764}
765
766static void MessageMapDealloc(PyObject* _self) {
767  MessageMapContainer* self = GetMessageMap(_self);
768  self->owner.reset();
769  Py_DECREF(self->message_dict);
770  Py_DECREF(self->message_class);
771  Py_TYPE(_self)->tp_free(_self);
772}
773
774static PyMethodDef MessageMapMethods[] = {
775  { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
776    "Tests whether the map contains this element."},
777  { "clear", (PyCFunction)Clear, METH_NOARGS,
778    "Removes all elements from the map."},
779  { "get", MessageMapGet, METH_VARARGS,
780    "Gets the value for the given key if present, or otherwise a default" },
781  { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
782    "Alias for getitem, useful to make explicit that the map is mutated." },
783  /*
784  { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
785    "Makes a deep copy of the class." },
786  { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
787    "Outputs picklable representation of the repeated field." },
788  */
789  {NULL, NULL},
790};
791
792#if PY_MAJOR_VERSION >= 3
793  static PyType_Slot MessageMapContainer_Type_slots[] = {
794      {Py_tp_dealloc, (void *)MessageMapDealloc},
795      {Py_mp_length, (void *)MapReflectionFriend::Length},
796      {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
797      {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
798      {Py_tp_methods, (void *)MessageMapMethods},
799      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
800      {0, 0}
801  };
802
803  PyType_Spec MessageMapContainer_Type_spec = {
804      FULL_MODULE_NAME ".MessageMapContainer",
805      sizeof(MessageMapContainer),
806      0,
807      Py_TPFLAGS_DEFAULT,
808      MessageMapContainer_Type_slots
809  };
810
811  PyObject *MessageMapContainer_Type;
812#else
813  static PyMappingMethods MessageMapMappingMethods = {
814    MapReflectionFriend::Length,              // mp_length
815    MapReflectionFriend::MessageMapGetItem,   // mp_subscript
816    MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
817  };
818
  // Statically-allocated type object for MessageMapContainer, used in this
  // branch of the preprocessor conditional. The other branch instead declares
  // MessageMapContainer_Type as a PyObject* (presumably created at runtime
  // from a type spec -- confirm).
  //
  // The initializers are positional and must follow PyTypeObject's field
  // order exactly. All fields after tp_init (tp_alloc, tp_new, ...) are
  // zero-initialized.
  PyTypeObject MessageMapContainer_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
    sizeof(MessageMapContainer),         //  tp_basicsize
    0,                                   //  tp_itemsize
    MessageMapDealloc,                   //  tp_dealloc
    0,                                   //  tp_print
    0,                                   //  tp_getattr
    0,                                   //  tp_setattr
    0,                                   //  tp_compare
    0,                                   //  tp_repr
    0,                                   //  tp_as_number
    0,                                   //  tp_as_sequence
    &MessageMapMappingMethods,           //  tp_as_mapping
    0,                                   //  tp_hash
    0,                                   //  tp_call
    0,                                   //  tp_str
    0,                                   //  tp_getattro
    0,                                   //  tp_setattro
    0,                                   //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,                  //  tp_flags
    "A map container for message",       //  tp_doc
    0,                                   //  tp_traverse
    0,                                   //  tp_clear
    0,                                   //  tp_richcompare
    0,                                   //  tp_weaklistoffset
    MapReflectionFriend::GetIterator,    //  tp_iter (yields MapIterator_Type)
    0,                                   //  tp_iternext
    MessageMapMethods,                   //  tp_methods
    0,                                   //  tp_members
    0,                                   //  tp_getset
    0,                                   //  tp_base
    0,                                   //  tp_dict
    0,                                   //  tp_descr_get
    0,                                   //  tp_descr_set
    0,                                   //  tp_dictoffset
    0,                                   //  tp_init
  };
857#endif
858
859// MapIterator /////////////////////////////////////////////////////////////////
860
861static MapIterator* GetIter(PyObject* obj) {
862  return reinterpret_cast<MapIterator*>(obj);
863}
864
865PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
866  MapContainer* self = GetMap(_self);
867
868  ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
869  if (obj == NULL) {
870    return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
871  }
872
873  MapIterator* iter = GetIter(obj.get());
874
875  Py_INCREF(self);
876  iter->container = self;
877  iter->version = self->version;
878  iter->owner = self->owner;
879
880  if (MapReflectionFriend::Length(_self) > 0) {
881    Message* message = self->GetMutableMessage();
882    const Reflection* reflection = message->GetReflection();
883
884    iter->iter.reset(new ::google::protobuf::MapIterator(
885        reflection->MapBegin(message, self->parent_field_descriptor)));
886  }
887
888  return obj.release();
889}
890
891PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
892  MapIterator* self = GetIter(_self);
893
894  // This won't catch mutations to the map performed by MergeFrom(); no easy way
895  // to address that.
896  if (self->version != self->container->version) {
897    return PyErr_Format(PyExc_RuntimeError,
898                        "Map modified during iteration.");
899  }
900
901  if (self->iter.get() == NULL) {
902    return NULL;
903  }
904
905  Message* message = self->container->GetMutableMessage();
906  const Reflection* reflection = message->GetReflection();
907
908  if (*self->iter ==
909      reflection->MapEnd(message, self->container->parent_field_descriptor)) {
910    return NULL;
911  }
912
913  PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
914                                 self->iter->GetKey());
915
916  ++(*self->iter);
917
918  return ret;
919}
920
// tp_dealloc for MapIterator. Releases the C++ iterator, the owner reference
// and the container reference, then frees the object's storage via the
// type's tp_free.
static void DeallocMapIterator(PyObject* _self) {
  MapIterator* self = GetIter(_self);
  // Reset the C++ iterator first. It was created from the container's message
  // (see GetIterator), so it should be destroyed while the owner reference is
  // still held; presumably the owner is what keeps that message alive --
  // confirm before reordering.
  self->iter.reset();
  self->owner.reset();
  // XDECREF: tolerate a null container pointer.
  Py_XDECREF(self->container);
  Py_TYPE(_self)->tp_free(_self);
}
928
929PyTypeObject MapIterator_Type = {
930  PyVarObject_HEAD_INIT(&PyType_Type, 0)
931  FULL_MODULE_NAME ".MapIterator",     //  tp_name
932  sizeof(MapIterator),                 //  tp_basicsize
933  0,                                   //  tp_itemsize
934  DeallocMapIterator,                  //  tp_dealloc
935  0,                                   //  tp_print
936  0,                                   //  tp_getattr
937  0,                                   //  tp_setattr
938  0,                                   //  tp_compare
939  0,                                   //  tp_repr
940  0,                                   //  tp_as_number
941  0,                                   //  tp_as_sequence
942  0,                                   //  tp_as_mapping
943  0,                                   //  tp_hash
944  0,                                   //  tp_call
945  0,                                   //  tp_str
946  0,                                   //  tp_getattro
947  0,                                   //  tp_setattro
948  0,                                   //  tp_as_buffer
949  Py_TPFLAGS_DEFAULT,                  //  tp_flags
950  "A scalar map iterator",             //  tp_doc
951  0,                                   //  tp_traverse
952  0,                                   //  tp_clear
953  0,                                   //  tp_richcompare
954  0,                                   //  tp_weaklistoffset
955  PyObject_SelfIter,                   //  tp_iter
956  MapReflectionFriend::IterNext,       //  tp_iternext
957  0,                                   //  tp_methods
958  0,                                   //  tp_members
959  0,                                   //  tp_getset
960  0,                                   //  tp_base
961  0,                                   //  tp_dict
962  0,                                   //  tp_descr_get
963  0,                                   //  tp_descr_set
964  0,                                   //  tp_dictoffset
965  0,                                   //  tp_init
966};
967
968}  // namespace python
969}  // namespace protobuf
970}  // namespace google
971