toco_python_api.cc revision 0b15439f8f0f2d4755587f4096c3ea04cb199d23
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <string>
16#include <vector>
17#include "tensorflow/core/platform/logging.h"
18
19#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
20#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
21#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
22#include "tensorflow/contrib/lite/toco/toco_port.h"
23#include "tensorflow/contrib/lite/toco/toco_tooling.h"
24#include "tensorflow/contrib/lite/toco/toco_types.h"
25
26namespace toco {
27
28#if PY_MAJOR_VERSION >= 3
29#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize
30#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize
31#else
32#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize
33#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize
34#endif
35
36// NOTE(aselle): We are using raw PyObject's here because we want to make
37// sure we input and output bytes rather than unicode strings for Python3.
38PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
39                      PyObject* toco_flags_proto_txt_raw,
40                      PyObject* input_contents_txt_raw) {
41  // Use Python C API to validate and convert arguments. In py3 (bytes),
42  // in py2 (str).
43  auto ConvertArg = [&](PyObject* obj, bool* error) {
44    char* buf;
45    Py_ssize_t len;
46    if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) {
47      *error = true;
48      return std::string();
49    } else {
50      *error = false;
51      return std::string(buf, len);
52    }
53  };
54
55  bool error;
56  std::string model_flags_proto_txt =
57      ConvertArg(model_flags_proto_txt_raw, &error);
58  if (error) return nullptr;
59  std::string toco_flags_proto_txt =
60      ConvertArg(toco_flags_proto_txt_raw, &error);
61  if (error) return nullptr;
62  std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
63  if (error) return nullptr;
64
65  // Use toco to produce new outputs
66  toco::ModelFlags model_flags;
67  if (!model_flags.ParseFromString(model_flags_proto_txt)) {
68    LOG(FATAL) << "Model proto failed to parse." << std::endl;
69  }
70  toco::TocoFlags toco_flags;
71  if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
72    LOG(FATAL) << "Toco proto failed to parse." << std::endl;
73  }
74  std::unique_ptr<toco::Model> model =
75      toco::Import(toco_flags, model_flags, input_contents_txt);
76  toco::Transform(toco_flags, model.get());
77  string output_file_contents_txt;
78  Export(toco_flags, *model, &output_file_contents_txt);
79
80  // Convert arguments back to byte (py3) or str (py2)
81  return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
82                                   output_file_contents_txt.size());
83}
84
85}  // namespace toco
86