1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16/* Wrap trt_conversion */
17%{
18#define SWIG_FILE_WITH_INIT
19%}
20%include "std_pair.i"
21%include "tensorflow/python/platform/base.i"
22
23%{
24PyObject* pair_helper(std::pair<string, string>* in) {
25  PyObject *first(nullptr), *second(nullptr), *tuple(nullptr);
26  first = PyBytes_FromStringAndSize(in->first.data(), in->first.length());
27  if (!first) {
28    if (!PyErr_Occurred()) {
29      PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed");
30    }
31    return NULL;
32  }
33  second = PyBytes_FromStringAndSize(in->second.data(), in->second.length());
34  if (!second) {
35    if (!PyErr_Occurred()) {
36      PyErr_SetString(PyExc_TypeError,
37                      "Pair conversion second argument failed");
38    }
39    return NULL;
40  }
41  tuple = Py_BuildValue("(OO)", first, second);
42  if (!tuple) {
43    if (!PyErr_Occurred()) {
44      PyErr_SetString(PyExc_TypeError,
45                      "Tuple creation from pair<string,string> failed!");
46    }
47    return NULL;
48  }
49  return tuple;
50}
51%}
// Output typemap: hands a returned std::pair<string, string> to Python as
// whatever pair_helper builds from it, and takes the wrapper's failure path
// when pair_helper returns null.
%typemap(out) std::pair<string, string> {
  $result = pair_helper(&$1);
  if (!$result) {
    SWIG_fail;
  }
}
57%{
58#include "tensorflow/core/lib/core/errors.h"
59#include "tensorflow/core/lib/core/status.h"
60#include "tensorflow/core/util/stat_summarizer.h"
61#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
62%}
63
64%ignoreall
65%unignore tensorflow;
66%unignore trt_convert;
67
68%{
// Parses a serialized GraphDef, runs ConvertGraphDefToTensorRT on it and
// returns the serialized result together with a status string.
//
// Args:
//   graph_def_string: the serialized GraphDef to convert.
//   output_names: forwarded to ConvertGraphDefToTensorRT; must be non-empty.
//   max_batch_size: forwarded to ConvertGraphDefToTensorRT.
//   max_workspace_size_bytes: forwarded to ConvertGraphDefToTensorRT.
//
// Returns:
//   A pair {status, serialized_graph}. `status` has the form
//   "<code>;<message>" and is "OK;All good!" on success. On any failure the
//   second element is empty.
//
// Unfortunately we can't report errors through a TF_Status* out parameter:
// it lives in c/c_api and brings in a lot of other libraries which in turn
// declare ops. These ops are included statically in our library and cause an
// abort when the module is loaded due to double registration. Until
// TensorFlow properly exposes these headers we work around this by returning
// a string and converting it to an exception on the Python side.
std::pair<string, string> trt_convert(
    string graph_def_string,  // The serialized GraphDef string.
    std::vector<string> output_names,
    size_t max_batch_size,
    size_t max_workspace_size_bytes) {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  // NOTE(review): the early exits below use a textual code
  // ("InvalidArgument") while the conversion-failure and #else paths use a
  // numeric code -- confirm the Python-side parser accepts both forms.
  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(graph_def_string)) {
    return std::pair<string, string>{
        "InvalidArgument;Couldn't interpret input as a GraphDef", ""};
  }

  if (output_names.empty()) {
    return std::pair<string, string>{
        "InvalidArgument;Size of the output_names vector is 0", ""};
  }

  tensorflow::GraphDef out_graph;
  const tensorflow::Status conversion_status =
      tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
          graph_def, output_names, max_batch_size, max_workspace_size_bytes,
          &out_graph);
  if (!conversion_status.ok()) {
    // Build "<numeric code>;<message>" by concatenation instead of through a
    // fixed-size char buffer, so long error messages are not truncated.
    string out_status =
        std::to_string(static_cast<int>(conversion_status.code()));
    out_status += ';';
    out_status += conversion_status.error_message();
    return std::pair<string, string>{out_status, ""};
  }

  string result;
  if (!out_graph.SerializeToString(&result)) {
    return std::pair<string, string>{
        "InvalidArgument;Couldn't serialize output as a GraphDef", ""};
  }
  // Move the serialized graph into the result to avoid copying it.
  return std::pair<string, string>{"OK;All good!", std::move(result)};
#else
  // Returns FAILED_PRECONDITION.
  return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
}
123%}
124
// Declaration placed outside any %{ ... %} section so that SWIG parses it and
// generates the Python wrapper for trt_convert. The std::pair<string, string>
// return value is converted by the output typemap defined in this file.
std::pair<string, string> trt_convert(
    string graph_def_string, std::vector<string> output_names,
    size_t max_batch_size, size_t max_workspace_size_bytes);
129
130
131%unignoreall
132