toco_python_api.cc revision 0b15439f8f0f2d4755587f4096c3ea04cb199d23
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <string> 16#include <vector> 17#include "tensorflow/core/platform/logging.h" 18 19#include "tensorflow/contrib/lite/toco/model_flags.pb.h" 20#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" 21#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" 22#include "tensorflow/contrib/lite/toco/toco_port.h" 23#include "tensorflow/contrib/lite/toco/toco_tooling.h" 24#include "tensorflow/contrib/lite/toco/toco_types.h" 25 26namespace toco { 27 28#if PY_MAJOR_VERSION >= 3 29#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize 30#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize 31#else 32#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize 33#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize 34#endif 35 36// NOTE(aselle): We are using raw PyObject's here because we want to make 37// sure we input and output bytes rather than unicode strings for Python3. 38PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, 39 PyObject* toco_flags_proto_txt_raw, 40 PyObject* input_contents_txt_raw) { 41 // Use Python C API to validate and convert arguments. In py3 (bytes), 42 // in py2 (str). 43 auto ConvertArg = [&](PyObject* obj, bool* error) { 44 char* buf; 45 Py_ssize_t len; 46 if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) { 47 *error = true; 48 return std::string(); 49 } else { 50 *error = false; 51 return std::string(buf, len); 52 } 53 }; 54 55 bool error; 56 std::string model_flags_proto_txt = 57 ConvertArg(model_flags_proto_txt_raw, &error); 58 if (error) return nullptr; 59 std::string toco_flags_proto_txt = 60 ConvertArg(toco_flags_proto_txt_raw, &error); 61 if (error) return nullptr; 62 std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); 63 if (error) return nullptr; 64 65 // Use toco to produce new outputs 66 toco::ModelFlags model_flags; 67 if (!model_flags.ParseFromString(model_flags_proto_txt)) { 68 LOG(FATAL) << "Model proto failed to parse." << std::endl; 69 } 70 toco::TocoFlags toco_flags; 71 if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { 72 LOG(FATAL) << "Toco proto failed to parse." << std::endl; 73 } 74 std::unique_ptr<toco::Model> model = 75 toco::Import(toco_flags, model_flags, input_contents_txt); 76 toco::Transform(toco_flags, model.get()); 77 string output_file_contents_txt; 78 Export(toco_flags, *model, &output_file_contents_txt); 79 80 // Convert arguments back to byte (py3) or str (py2) 81 return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), 82 output_file_contents_txt.size()); 83} 84 85} // namespace toco 86