/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* Wrap trt_conversion */
%{
#define SWIG_FILE_WITH_INIT
%}
%include "std_pair.i"
%include "tensorflow/python/platform/base.i"

%{
// Converts `*in` into a new Python 2-tuple `(bytes(in->first),
// bytes(in->second))`.
//
// Returns a new reference to the tuple on success. On failure returns nullptr
// with a Python exception set (a TypeError is raised if the failing CPython
// call did not already set one). Every intermediate object created here is
// released on all paths, so nothing leaks.
PyObject* pair_helper(const std::pair<string, string>* in) {
  PyObject* first =
      PyBytes_FromStringAndSize(in->first.data(), in->first.length());
  if (!first) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed");
    }
    return nullptr;
  }
  PyObject* second =
      PyBytes_FromStringAndSize(in->second.data(), in->second.length());
  if (!second) {
    // Release the already-created first element before bailing out.
    Py_DECREF(first);
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_TypeError,
                      "Pair conversion second argument failed");
    }
    return nullptr;
  }
  // The "O" format makes Py_BuildValue take its own references to the items.
  // Our references must therefore be dropped whether or not the build
  // succeeded; otherwise both bytes objects leak.
  PyObject* tuple = Py_BuildValue("(OO)", first, second);
  Py_DECREF(first);
  Py_DECREF(second);
  if (!tuple) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_TypeError,
                      "Tuple creation from pair<string,string> failed!");
    }
    return nullptr;
  }
  return tuple;
}
%}
// Return std::pair<string, string> to Python as a (bytes, bytes) tuple.
%typemap(out) std::pair<string, string> {
  PyObject *tuple = pair_helper(&$1);
  if (!tuple) SWIG_fail;
  $result = tuple;
}
%{
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
%}

%ignoreall
%unignore tensorflow;
%unignore trt_convert;

%{
// Converts a serialized GraphDef into a TensorRT-optimized GraphDef.
//
// Parameters:
//   graph_def_string: the serialized input GraphDef.
//   output_names: names of the graph outputs; must be non-empty.
//   max_batch_size, max_workspace_size_bytes: forwarded unchanged to
//     tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT.
//
// Returns a pair {status, serialized_output_graph}:
//   * On success, status is "OK;All good!" and the second element holds the
//     serialized converted GraphDef.
//   * On failure, status is "<code>;<message>" and the second element is
//     empty. The Python side turns this string into an exception.
//
// A string status is used instead of TF_Status because TF_Status lives in
// c/c_api. That library brings in other libraries which register ops; those
// ops are also linked statically into this module, and the duplicate
// registration aborts at module load. Until TensorFlow exposes these headers
// separately, returning a string is the workaround.
//
// NOTE(review): the input-validation paths use the code name
// "InvalidArgument", while the conversion-failure path uses the numeric
// status code. Confirm that the Python-side parser accepts both forms.
std::pair<string, string> trt_convert(
    string graph_def_string,  // The serialized GraphDef string.
    std::vector<string> output_names,
    size_t max_batch_size,
    size_t max_workspace_size_bytes
    //,TF_Status* out_status) {
) {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(graph_def_string)) {
    return std::pair<string, string>{
        "InvalidArgument;Couldn't interpret input as a GraphDef", ""};
  }

  if (output_names.empty()) {
    return std::pair<string, string>{
        "InvalidArgument;Size of the output_names vector is 0", ""};
  }

  tensorflow::GraphDef out_graph;
  const tensorflow::Status conversion_status =
      tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
          graph_def, output_names, max_batch_size, max_workspace_size_bytes,
          &out_graph);
  if (!conversion_status.ok()) {
    // Build "<numeric code>;<message>" without a fixed-size buffer so long
    // error messages are not silently truncated.
    const int ret_code = static_cast<int>(conversion_status.code());
    return std::pair<string, string>{
        std::to_string(ret_code) + ";" + conversion_status.error_message(),
        ""};
  }

  string result;
  if (!out_graph.SerializeToString(&result)) {
    return std::pair<string, string>{
        "InvalidArgument;Couldn't serialize output as a GraphDef", ""};
  }
  // Move the (potentially large) serialized graph instead of copying it.
  return std::pair<string, string>{"OK;All good!", std::move(result)};
#else
  // Returns FAILED_PRECONDITION.
  return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
}
%}

std::pair<string, string> trt_convert(string graph_def_string,
                                      std::vector<string> output_names,
                                      size_t max_batch_size,
                                      size_t max_workspace_size_bytes);


%unignoreall