/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"

// Check if the types are equal. Cast to int first so that the failure log
// message works.
#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace {

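// Maps a TensorFlow DataType to the corresponding TensorRT DataType.
// Only DT_FLOAT, DT_INT8 and DT_HALF are mapped; anything else returns an
// InvalidArgument error.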
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
                                       nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case tensorflow::DataType::DT_FLOAT:
      *trt_dtype = nvinfer1::DataType::kFLOAT;
      break;
    case tensorflow::DataType::DT_INT8:
      *trt_dtype = nvinfer1::DataType::kINT8;
      break;
    case tensorflow::DataType::DT_HALF:
      *trt_dtype = nvinfer1::DataType::kHALF;
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported data type");
  }
  return tensorflow::Status::OK();
}

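// Builds an nvinfer1::Dims from the dimension sizes of a TensorFlow tensor.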
inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
  nvinfer1::Dims dims;
  dims.nbDims = tensor.dims();
  for (int i = 0; i < dims.nbDims; i++) {
    dims.d[i] = tensor.dim_size(i);
  }
  return dims;
}

inline int64_t GetShapeSize(nvinfer1::Dims shape) {
  // Returns total number of elements in shape
  int64_t count = 1;
  for (int d = 0; d < shape.nbDims; ++d) {
    count *= shape.d[d];
  }
  return count;
}

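// Computes per-dimension (before, after) padding for TensorFlow-style SAME
// padding:
//   total = max(0, ((input - 1) / stride) * stride + kernel - input)
// using integer division; when total is odd the extra element goes on the
// right/bottom, matching TensorFlow's convention.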
static std::vector<std::pair<int, int>> CreateSamePadding(
    const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
    const std::vector<int64_t>& input_dims) {
  std::vector<std::pair<int, int>> padding(input_dims.size());
  CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?

  for (size_t i = 0; i < input_dims.size(); ++i) {
    // Formula to calculate the padding
    int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
            input_dims[i];
    p = (p > 0) ? p : 0;

    // Right precedence padding, like in TensorFlow
    int left = p / 2;
    int right = p - left;

    VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
            << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
            << "kernel: " << kernel.d[i];
    padding[i] = {left, right};
  }
  return padding;
}

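// Non-owning view of a constant weight blob: dtype, dims and a raw pointer to
// the values. Convertible to nvinfer1::Weights so it can be handed directly
// to TensorRT layer builders; the backing memory is owned elsewhere.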
class TRT_ShapedWeights {
 public:
  TRT_ShapedWeights(tensorflow::DataType type, const void* values,
                    nvinfer1::Dims shape)
      : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
    // Note: this->shape.type[] is not used
  }

  explicit TRT_ShapedWeights(tensorflow::DataType type)
      : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}

  TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
      : shape_(rhs.shape_),
        type_(rhs.type_),
        values_(rhs.values_),
        empty_weight_flag_(rhs.empty_weight_flag_) {}

  int64_t count() const {
    int64_t c = 1;
    for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    return c;
  }

  nvinfer1::Weights GetWeightsForTRT() const {
    nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(type_, &trt_type));
    if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};

    // Note: this->shape.type[] is not used
    return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
  }

  const void* GetValues() const { return values_; }

  void SetValues(const void* values) { values_ = values; }

  size_t size_bytes() const {
    int type_size = tensorflow::DataTypeSize(this->type_);
    return this->count() * type_size;
  }

  // Default converter
  operator nvinfer1::Weights() const { return GetWeightsForTRT(); }

  nvinfer1::Dims shape_;
  tensorflow::DataType type_;

 private:
  const void* values_;
  bool empty_weight_flag_;
};

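// A TF graph edge can carry either a TensorRT tensor (the output of an
// already-converted op) or constant weights; this tagged union holds one or
// the other and checks the variant on access.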
class TRT_TensorOrWeights {
 public:
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
      : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
      : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
      : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
  ~TRT_TensorOrWeights() {}

  bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
  bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }

  nvinfer1::ITensor* tensor() {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  const nvinfer1::ITensor* tensor() const {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  TRT_ShapedWeights& weights() {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  const TRT_ShapedWeights& weights() const {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  nvinfer1::Dims shape() const {
    if (is_tensor()) {
      return tensor()->getDimensions();
    } else {
      return weights().shape_;
    }
  }

 private:
  nvinfer1::ITensor* tensor_;
  TRT_ShapedWeights weights_;
  enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
};

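// Thin wrapper around a NodeDef's attribute map with typed get<T>() accessors
// for the attribute types the converter cares about.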
class TFAttrs {
 public:
  explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    for (const auto& attr : tf_node.attr()) {
      attrs_.insert({attr.first, &attr.second});
    }
  }
  bool count(string key) const { return attrs_.count(key); }
  tensorflow::AttrValue const* at(string key) const {
    if (!attrs_.count(key)) {
      LOG(FATAL) << "Attribute not found: " << key;
    }
    return attrs_.at(key);
  }
  template <typename T>
  T get(string key) const;
  template <typename T>
  T get(string key, const T& default_value) const {
    return attrs_.count(key) ? this->get<T>(key) : default_value;
  }

 private:
  typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
  AttrMap attrs_;
};

template <>
string TFAttrs::get<string>(string key) const {
  return this->at(key)->s();
}

template <>
std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
  auto attr = this->at(key)->list().i();
  return std::vector<int>(attr.begin(), attr.end());
}

template <>
nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
  auto values = this->get<std::vector<int>>(key);
  nvinfer1::Dims dims;
  dims.nbDims = values.size();
  std::copy(values.begin(), values.end(), dims.d);
  // Note: No dimension type information is included
  return dims;
}

template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
  return trt_dtype;
}

template <>
tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
  return this->at(key)->type();
}

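// Copies a 4-D blob element by element from idata to odata using the given
// input and output strides; used below to shuffle convolution weights between
// layouts.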
template <typename T>
void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
              nvinfer1::DimsNCHW istrides, T* odata,
              nvinfer1::DimsNCHW ostrides) {
  for (int n = 0; n < shape.n(); ++n) {
    for (int c = 0; c < shape.c(); ++c) {
      for (int h = 0; h < shape.h(); ++h) {
        for (int w = 0; w < shape.w(); ++w) {
          odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
                w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
                                          h * istrides.h() + w * istrides.w()];
        }
      }
    }
  }
}

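// Converts convolution kernel weights from TensorFlow's RSCK layout
// (filter height, filter width, input channels, output channels) to the KCRS
// layout that TensorRT expects.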
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
                       TRT_ShapedWeights* oweights) {
  CHECK_EQ(iweights.type_, oweights->type_);
  CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
  int r = iweights.shape_.d[0];
  int s = iweights.shape_.d[1];
  int c = iweights.shape_.d[2];
  int k = iweights.shape_.d[3];
  oweights->shape_.d[0] = k;
  oweights->shape_.d[1] = c;
  oweights->shape_.d[2] = r;
  oweights->shape_.d[3] = s;
  nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
  nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT:
      Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
               istrides,
               static_cast<float*>(const_cast<void*>(oweights->GetValues())),
               ostrides);
      break;
    default:
      LOG(FATAL) << "Unsupported type in weight reordering, expected DT_FLOAT "
                 << "but got " << tensorflow::DataTypeString(iweights.type_);
  }
}

struct InferDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

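// Wraps a raw TensorRT object in a std::shared_ptr whose deleter calls
// destroy() instead of delete, since TensorRT objects must be released that
// way.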
template <typename T>
inline std::shared_ptr<T> infer_object(T* obj) {
  return std::shared_ptr<T>(obj, InferDeleter());
}

// Forward declaration so the OpConverter signature below can refer to it.
class Converter;

using OpConverter =
    std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
                                     std::vector<TRT_TensorOrWeights> const&,
                                     std::vector<TRT_TensorOrWeights>*)>;

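// Drives the conversion of a TensorFlow subgraph into a TensorRT network:
// keeps a map from TF tensor names to converted tensors/weights, owns the
// temporary buffers backing folded weights, and dispatches each NodeDef to
// the OpConverter registered for its op type.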
class Converter {
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
  std::unordered_map<string, OpConverter> op_registry_;
  nvinfer1::INetworkDefinition* trt_network_;
  std::list<std::vector<uint8_t>> temp_bufs_;

  void register_op_converters();

  std::vector<TRT_TensorOrWeights> get_inputs(
      const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs;
    for (const auto& input_name : node_def.input()) {
      VLOG(2) << "Retrieve input: " << input_name;
      inputs.push_back(trt_tensors_.at(input_name));
    }
    return inputs;
  }

 public:
  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
      : trt_network_(trt_network) {
    this->register_op_converters();
  }

  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                     nvinfer1::Dims shape) {
    TRT_ShapedWeights weights(type, nullptr, shape);
    // TODO(jie): check weights size_bytes. 0 means type error
    temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    weights.SetValues(temp_bufs_.back().data());
    return weights;
  }

  TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    return this->get_temp_weights(weights.type_, weights.shape_);
  }

  tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    string op = node_def.op();
    if (!op_registry_.count(op)) {
      return tensorflow::errors::Unimplemented(
          "No converter registered for op: " + op);
    }
    OpConverter op_converter = op_registry_.at(op);
    std::vector<TRT_TensorOrWeights> outputs;
    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    for (size_t i = 0; i < outputs.size(); ++i) {
      TRT_TensorOrWeights output = outputs.at(i);
      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
      string output_name = node_def.name();
      if (i != 0) output_name = output_name + ":" + std::to_string(i);
      if (output.is_tensor()) {
        output.tensor()->setName(output_name.c_str());
      }
      VLOG(2) << "Write out tensor: " << output_name;
      if (!trt_tensors_.insert({output_name, output}).second) {
        return tensorflow::errors::AlreadyExists(
            "Output tensor already exists for op: " + op);
      }
    }
    return tensorflow::Status::OK();
  }

  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  TRT_TensorOrWeights get_tensor(string name) {
    if (!trt_tensors_.count(name)) {
      return TRT_TensorOrWeights(nullptr);
    }
    return trt_tensors_.at(name);
  }

  bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
  }

  nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
                                     std::vector<int> order) {
    auto dims = input_tensor->getDimensions();

    // TODO(jie): change the return to status and properly exit
    if (order.size() - 1 != size_t(dims.nbDims))
      LOG(ERROR) << "Dimension does not match, fail gracefully";

    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    nvinfer1::Permutation permutation;
    for (int32_t i = 0; i < dims.nbDims; ++i) {
      permutation.order[i] = order[i + 1] - 1;
    }
    layer->setFirstTranspose(permutation);

    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = dims.nbDims;
    for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
      reshape_dims.d[i] = 0;
      reshape_dims.type[i] = dims.type[i];
    }
    layer->setReshapeDimensions(reshape_dims);
    return layer->getOutput(0);
  }
};

// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
// there.
// ****************************************************************************
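// Produces scalar lambdas (unary, binary, and scalar-broadcast variants) used
// by the constant-folding helpers below to evaluate element-wise ops on
// weights at conversion time.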
struct LambdaFactory {
  enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
  OP_CATEGORY op;

  template <typename T>
  std::function<T(T)> unary() {
    switch (op) {
      case OP_CATEGORY::RSQRT: {
        VLOG(2) << "RSQRT GETS DONE";
        return [](T t) -> T { return 1.0 / std::sqrt(t); };
      }
      case OP_CATEGORY::NEG:
        return [](T t) -> T { return -t; };
      default:
        VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
        return nullptr;
    }
  }

  template <typename T>
  std::function<T(T, T)> binary() {
    switch (op) {
      case OP_CATEGORY::ADD:
        return [](T l, T r) -> T { return l + r; };
      case OP_CATEGORY::SUB:
        return [](T l, T r) -> T { return l - r; };
      case OP_CATEGORY::MUL:
        return [](T l, T r) -> T { return l * r; };
      default:
        LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [](T l, T r) -> T {
      LOG(FATAL) << "Unsupported op type ";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_r(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l + val;
        };
      // Return [val](T l)-> T {return l+val;};
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l - val;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l * val;
        };
      default:
        LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type ";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_l(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val + l;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val - l;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val * l;
        };
      default:
        LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type ";
      return l;
    };
  }
};

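// Applies a unary lambda element-wise to constant weights (DT_FLOAT only),
// writing the result into pre-allocated output weights.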
tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
                                TRT_ShapedWeights* oweights,
                                LambdaFactory unary_op) {
  CHECK_EQ(iweights.type_, oweights->type_);
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp = static_cast<float const*>(iweights.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
      std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights.type_));
  }
  return tensorflow::Status::OK();
}

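// Applies a binary lambda element-wise to two constant weight blobs
// (DT_FLOAT only), broadcasting when one side holds a single scalar.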
tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
                                 const TRT_ShapedWeights& iweights_r,
                                 TRT_ShapedWeights* oweights,
                                 LambdaFactory binary_op) {
  // Assume iweights_l.type == iweights_r.type
  CHECK_EQ(iweights_l.type_, oweights->type_);
  CHECK_EQ(iweights_r.type_, oweights->type_);
  VLOG(2) << "SANITY CHECK!";

  switch (iweights_l.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp_l = static_cast<const float*>(iweights_l.GetValues());
      auto inp_r = static_cast<const float*>(iweights_r.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));

      if (iweights_l.count() != iweights_r.count()) {
        // Only broadcasting of rank-zero (scalar) operands is supported
        if (iweights_l.count() == 1) {
          VLOG(2) << "I bet it is not working!" << (*inp_l);
          std::transform(inp_r, inp_r + iweights_r.count(), oup,
                         binary_op.broadcast_l<float>(*inp_l));
        } else if (iweights_r.count() == 1) {
          VLOG(2) << "I bet it is not working!" << (*inp_r);
          std::transform(inp_l, inp_l + iweights_l.count(), oup,
                         binary_op.broadcast_r<float>(*inp_r));
        } else {
          return tensorflow::errors::Unimplemented(
              "Binary op with non-rankZero broadcast not supported");
        }
      } else {
        std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
                       binary_op.binary<float>());
      }
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights_l.type_));
  }

  return tensorflow::Status::OK();
}

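// Folds a unary op on constant weights at conversion time; currently only
// Rsqrt is handled.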
tensorflow::Status ConstantFoldUnary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input = inputs.at(0).weights();

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);

  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later
  // Check type consistency
  CHECK_EQ(weights_input.type_,
           TFAttrs(node_def).get<tensorflow::DataType>("T"));

  // Maybe I should do a switch
  LambdaFactory unary_op;
  if (node_def.op() == "Rsqrt") {
    // Compute rsqrt
    unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    // Pass the output
    if (ret == tensorflow::Status::OK()) {
      outputs->push_back(TRT_TensorOrWeights(weights_output));
    }
    return ret;
  } else {
    return tensorflow::errors::Unimplemented("Unary op not supported: " +
                                             node_def.op());
  }
}

// TODO(jie,ben) broadcast is needed yet not implemented
// Let's get the simple stuff working first. Maybe we should fall back to the
// TF approach for constant folding.
tensorflow::Status ConstantFoldBinary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
  TRT_ShapedWeights weights_input_r = inputs.at(1).weights();

  // Check type consistency
  CHECK_EQ(weights_input_l.type_, weights_input_r.type_);

  if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
    return tensorflow::errors::Unimplemented(
        "Binary op implicit broadcast not supported: " + node_def.op());

  // TODO(jie): constant fold should really fall back to TF.
  int nb_dims = weights_input_l.shape_.nbDims;
  nvinfer1::Dims output_shape;
  output_shape.nbDims = nb_dims;
  VLOG(2) << "nb_dims: " << nb_dims
          << ", the other: " << weights_input_r.shape_.nbDims;
  for (int i = 0; i < nb_dims; i++) {
    if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
      output_shape.d[i] = weights_input_l.shape_.d[i];
    } else if (weights_input_l.shape_.d[i] == 1 ||
               weights_input_r.shape_.d[i] == 1) {
      output_shape.d[i] =
          std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
    } else {
      return tensorflow::errors::Unimplemented(
          "Binary op with incompatible shape, at " + node_def.op());
    }
    VLOG(2) << "left: " << weights_input_l.shape_.d[i]
            << ", right: " << weights_input_r.shape_.d[i]
            << ", output: " << output_shape.d[i];
  }

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);

  // Maybe I should do a switch
  LambdaFactory binary_op;
  if (node_def.op() == "Sub") {
    binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
  } else if (node_def.op() == "Mul") {
    binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
  } else if (node_def.op() == "Add") {
    binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }
  auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
                           binary_op);

  // Pass the output
  if (ret == tensorflow::Status::OK()) {
    outputs->push_back(TRT_TensorOrWeights(weights_output));
  }

  return ret;
}

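// Converts a binary op between a TRT tensor and constant weights into an
// IScaleLayer: Sub negates the weights and uses them as the shift term, Mul
// uses them as the scale term, and Add uses them as the shift term.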
// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel wise for the time being.
tensorflow::Status BinaryTensorOpWeight(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
    std::vector<TRT_TensorOrWeights>* outputs) {
  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later

  // Check type consistency
  auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
  CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
  nvinfer1::DataType ttype;
  TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
  CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error message

  // Check scale mode
  auto dims_w = weights.shape_;
  auto dims_t = tensor->getDimensions();

  // Default to element-wise
  auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;

  if (weights.count() == 1) {
    VLOG(2) << "UNIFORM";
    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
  } else {
    // No broadcasting on the batch dimension;
    assert(dims_w.d[0] == 1);

    // Broadcasting on the channel dimension is only allowed in kUNIFORM
    assert(dims_w.d[1] == dims_t.d[0]);
    assert(dims_w.nbDims == dims_t.nbDims);

    // Default is element-wise;
    for (int i = 2; i < dims_w.nbDims; i++) {
      if (dims_w.d[i] != dims_t.d[i - 1]) {
        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
        break;
      }
    }
    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
      for (int i = 2; i < dims_w.nbDims; i++) {
        if (dims_w.d[i] != 1)
          return tensorflow::errors::InvalidArgument(
              "Weight shape not compatible, at " + node_def.name());
      }
    }
  }

  // Prepare weights
  TRT_ShapedWeights shift_weights(weights.type_);
  TRT_ShapedWeights scale_weights(weights.type_);
  TRT_ShapedWeights power_weights(weights.type_);

  // Maybe I should do a switch
  if (node_def.op() == "Sub") {
    TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
    LambdaFactory unary_op;
    unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
    TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    shift_weights = neg_weights;
  } else if (node_def.op() == "Mul") {
    scale_weights = weights;
  } else if (node_def.op() == "Add") {
    shift_weights = weights;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }

  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
      scale_weights, power_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

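// Converts an element-wise binary op between two TRT tensors into an
// IElementWiseLayer. Add/Mul/Sub/Div map to the corresponding
// nvinfer1::ElementWiseOperation; both inputs must have the same TRT type.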
tensorflow::Status BinaryTensorOpTensor(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    std::vector<TRT_TensorOrWeights>* outputs) {
  static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
      {"Add", nvinfer1::ElementWiseOperation::kSUM},
      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
      // {"max", nvinfer1::ElementWiseOperation::kMAX},
      // {"min", nvinfer1::ElementWiseOperation::kMIN},
      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
      {"Div", nvinfer1::ElementWiseOperation::kDIV},
  };

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");

  // Check type consistency
  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
  auto op_pair = ops.find(node_def.op());
  if (op_pair == ops.end())
    return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
                                             " not supported at: " +
                                             node_def.name());

  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
      *const_cast<nvinfer1::ITensor*>(tensor_l),
      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

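// Placeholder nodes are expected to have been replaced with engine inputs
// before conversion, so reaching this converter is reported as an error.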
tensorflow::Status ConvertPlaceholder(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  VLOG(2) << "Placeholder should have been replaced already";
  return tensorflow::errors::Unimplemented("Cannot convert Placeholder op");
  // OK this makes sense since we are supposed to replace it with input
  TFAttrs attrs(node_def);
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");

  dims.nbDims--;
  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];

  nvinfer1::ITensor* output =
      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
  if (!output) {
    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
  }
  outputs->push_back(TRT_TensorOrWeights(output));
  return tensorflow::Status::OK();
}

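// Converts a TF Conv2D node into a TRT IConvolutionLayer. The RSCK kernel
// weights are reordered to KCRS, biases are left empty, NHWC inputs are
// transposed to NCHW, and asymmetric SAME padding is emulated with an
// explicit padding layer before the convolution.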
tensorflow::Status ConvertConv2D(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 const std::vector<TRT_TensorOrWeights>& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
  ReorderRSCKToKCRS(weights_rsck, &weights);
  TRT_ShapedWeights biases(weights.type_);
  int noutput = weights.shape_.d[0];
  nvinfer1::DimsHW kernel_size;
  kernel_size.h() = weights.shape_.d[2];
  kernel_size.w() = weights.shape_.d[3];
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    h_index = 1;
    w_index = 2;
    // TODO(jie): transpose it
  }

  // TODO(jie): stride. (NHWC/NCHW)
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is an NCHW tensor with no batch dimension.
    //  1 -> h
    //  2 -> w
    padding = CreateSamePadding(
        stride, kernel_size,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else {
    padding = {{0, 0}, {0, 0}};
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
            << ", " << padding[1].first << ", " << padding[1].second;

    auto dim_before = tensor->getDimensions();
    VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
            << ", " << dim_before.d[2] << ", " << dim_before.d[3];
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
    auto dim_after = tensor->getDimensions();
    VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
            << ", " << dim_after.d[2] << ", " << dim_after.d[3];
  }

  nvinfer1::IConvolutionLayer* layer =
      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
                                    noutput, kernel_size, weights, biases);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  auto dim_after = output_tensor->getDimensions();
  VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
          << dim_after.d[2] << ", " << dim_after.d[3];

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

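// Converts a TF pooling node into a TRT IPoolingLayer. Only MaxPool is
// supported; NHWC inputs are transposed to NCHW, and asymmetric SAME padding
// is emulated with an explicit padding layer, mirroring the Conv2D path.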
tensorflow::Status ConvertPool(Converter& ctx,
                               const tensorflow::NodeDef& node_def,
                               std::vector<TRT_TensorOrWeights> const& inputs,
                               std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    h_index = 1;
    w_index = 2;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::PoolingType type;
  // TODO(jie): support other pooling type
  if (node_def.op() == "MaxPool")
    type = nvinfer1::PoolingType::kMAX;
  else
    return tensorflow::errors::Unimplemented("Only supports Max pool");

  // TODO(jie): NCHW
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tf_kernel = attrs.get<std::vector<int>>("ksize");
  nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is an NCHW tensor with no batch dimension.
    //  1 -> h
    //  2 -> w
    padding = CreateSamePadding(
        stride, ksize,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else if (attrs.get<string>("padding") == "VALID") {
    // No padding for valid padding here
    VLOG(2) << "No padding added for VALID padding in pool " << node_def.name();
    padding = {{0, 0}, {0, 0}};
  } else {
    return tensorflow::errors::Unimplemented(
        "Current MaxPool cannot support padding other than SAME and VALID");
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
            << ", " << padding[1].first << ", " << padding[1].second;
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
  }

  nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
      *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

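// Converts an activation node into a TRT IActivationLayer; only ReLU is
// emitted for now.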
tensorflow::Status ConvertActivation(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

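// Converts a channel-wise tensor/weights op into a TRT IScaleLayer in
// kCHANNEL mode; NHWC inputs are transposed to NCHW before the scale layer
// and back afterwards. The constant weights occupy the shift slot of
// addScale (the same shift/scale/power ordering used in BinaryTensorOpWeight
// above).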
tensorflow::Status ConvertScale(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::Unimplemented(
        "Only supports tensor op weight for now, at " + node_def.name());
  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();

  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights = inputs.at(1).weights();
  TRT_ShapedWeights empty_weights(weights.type_);

  TFAttrs attrs(node_def);

  // Transpose NHWC
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    // TODO(jie): transpose it
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
      weights, empty_weights, empty_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

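// Converts a TF Const node into TRT_ShapedWeights. Values stored in
// float_val are wrapped directly; values stored in tensor_content are copied
// into temporary weights owned by the converter.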
tensorflow::Status ConvertConst(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  const auto& weights_tensor = node_def.attr().at("value").tensor();

  // Get trt type & shape
  TFAttrs attrs(node_def);
  const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");

  // Create shaped weights as output
  tensorflow::Tensor tensor;
  if (!tensor.FromProto(weights_tensor))
    return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
                                        node_def.name());

  TRT_ShapedWeights weights(dtype);
  if (!weights_tensor.float_val().empty()) {
    VLOG(2) << "SCALAR!!! " << node_def.name();
    nvinfer1::Dims scalar_shape;
    if (tensor.dims() > 0) {
      VLOG(2) << "Dimensions: " << tensor.dims();
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  GetTensorShape(tensor));
    } else {
      VLOG(2) << "Dimensions: " << tensor.dims();
      scalar_shape.nbDims = 1;
      scalar_shape.d[0] = 1;
      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
        scalar_shape.d[i] = 0;
        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
      }
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  scalar_shape);
    }
  } else if (!weights_tensor.tensor_content().empty()) {
    VLOG(2) << "TENSOR!!! " << node_def.name();
    const auto& content = weights_tensor.tensor_content();

    weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
    if (content.size() > 0) {
      const int dtype_size = tensorflow::DataTypeSize(dtype);
      CHECK_EQ(0, content.size() % dtype_size)
          << "Tensor content size (" << content.size()
          << ") is not a multiple of " << dtype_size;
      port::CopyToArray(
          content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
    }
  } else {
    return tensorflow::errors::Unimplemented(
        "Not supported constant type, at " + node_def.name());
  }
  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(weights));
  return tensorflow::Status::OK();
}

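// Identity-like ops simply forward their input unchanged.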
tensorflow::Status ConvertIdentity(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  outputs->push_back(inputs.at(0));
  return tensorflow::Status::OK();
}

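// Dispatches a binary op based on its operand kinds: weights/weights is
// constant-folded, tensor/weights goes through the scale-layer path, and
// tensor/tensor goes through the element-wise path.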
tensorflow::Status ConvertBinary(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2)
    return tensorflow::errors::FailedPrecondition(
        "Binary ops require two inputs, at " + node_def.name());

  if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
    return ConstantFoldBinary(ctx, node_def, inputs, outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).weights(), outputs);

  if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
                                inputs.at(0).weights(), outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
    return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).tensor(), outputs);

  return tensorflow::errors::Unknown("Binary op input error, at " +
                                     node_def.name());
}

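// Unary ops are currently only supported when the input is constant weights,
// in which case they are folded on the host; tensor inputs are rejected.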
tensorflow::Status ConvertUnary(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 1)
    return tensorflow::errors::FailedPrecondition(
        "Unary ops require a single input, at " + node_def.name());

  if (inputs.at(0).is_weights())
    return ConstantFoldUnary(ctx, node_def, inputs, outputs);
  else if (inputs.at(0).is_tensor())
    return tensorflow::errors::Unimplemented(
        "Unary op for tensor not supported, at " + node_def.name());

  return tensorflow::errors::Unknown("Unary op input error, at " +
                                     node_def.name());
}

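// Converts a Mean reduce over spatial dimensions into an average-pooling
// layer, since a dedicated reduce layer is not available in the public TRT
// API used here. Reducing the batch dimension is rejected, and reducing the
// channel dimension is handled by first transposing it into a spatial slot.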
tensorflow::Status ConvertReduce(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights index_list = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // TODO(jie): handle data type.
  // Index type here is done through TF type, so I can leverage their
  // EnumToDataType for my cast
  auto index_type = attrs.get<tensorflow::DataType>("Tidx");

  // Only expect to handle INT32 as attributes for now
  if (index_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
  auto index_list_data =
      static_cast<int*>(const_cast<void*>(index_list.GetValues()));

  // Hack warning: have to fall back to pool layer since reduce is not in
  // public TRT yet.
  if (nb_dims != 4)
    return tensorflow::errors::InvalidArgument(
        "TRT only supports reduce on 4 dimensional tensors, at " +
        node_def.name());
  if (index_list.count() > 2)
    return tensorflow::errors::InvalidArgument(
        "TRT cannot support reduce on more than 2 dimensions, at " +
        node_def.name());

  std::set<int> idx_set;
  // Pooling cannot operate on the channel dimension directly; the permutation
  // flag is used to transpose that dimension into a spatial slot first.
  int permuted_index = -1;
  for (int i = 0; i < index_list.count(); i++) {
    if (index_list_data[i] == 0)
      return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at " +
                                                 node_def.name());
    if (index_list_data[i] == 1) permuted_index = 1;
    idx_set.emplace(index_list_data[i]);
  }

  std::vector<int> permutation_order(nb_dims);
  nvinfer1::DimsHW pool_kernel;
  if (permuted_index == 1) {
    for (int i = 2; i < nb_dims; i++) {
      if (idx_set.count(i)) {
        permuted_index = i;
        break;
      }
    }
    for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;

    permutation_order[permuted_index] = 1;
    permutation_order[1] = permuted_index;

    // Apply the permutation before extracting the dimensions for pool_kernel
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 permutation_order);
  }

  // Extract the pooling kernel dimensions (after any permutation)
  pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
  pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;

  nvinfer1::ITensor* output_tensor;

  if (node_def.op() == "Mean") {
    nvinfer1::IPoolingLayer* layer =
        ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
                                  nvinfer1::PoolingType::kAVERAGE, pool_kernel);
    output_tensor = layer->getOutput(0);
  } else {
    return tensorflow::errors::Unimplemented(
        "Op not supported " + node_def.op() + ", at " + node_def.name());
  }
  if (permuted_index != -1) {
    // Restore the original dimension ordering
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

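// Converts a TF Pad node with explicit INT32 paddings into a TRT padding
// layer. At most two padded axes are supported, padding the batch dimension
// is rejected, and an all-zero pad simply forwards the input.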
1277825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertPad(Converter& ctx,
1278bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                              const tensorflow::NodeDef& node_def,
1279825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                              std::vector<TRT_TensorOrWeights> const& inputs,
1280825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                              std::vector<TRT_TensorOrWeights>* outputs) {
1281825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
1282825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      !inputs.at(1).is_weights())
1283825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument(
1284825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Input expects tensor and weights, at" + node_def.name());
1285825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1286bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Implement tensor binaryOp weight [channel wise] for now;
1287825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
1288825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto dims = tensor->getDimensions();
1289bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Restore implicit batch dimension
1290bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  int nb_dims = dims.nbDims + 1;
1291825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1292825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights pads = inputs.at(1).weights();
1293825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1294825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
1295bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Padding type here is done through TF type
1296825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  //   so I can leverage their EnumToDataType for my cast
1297825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
1298825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): handle data type conversion for TRT?
1299825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1300bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
1301825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument(
1302825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Pad only supports explicit padding on 4 dimensional tensor, at " +
1303825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        node_def.name());
1304825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1305825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // Only expect to handle INT32 as attributes for now
1306825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (padding_type != tensorflow::DataType::DT_INT32)
1307825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1308825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Tpaddings supports only DT_INT32");
1309cfa374cefe132be886c26a374c51454177c68868gracehoney  auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
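  // For illustration (values hypothetical): for a 4-D NHWC input, "paddings"
  // is an [nb_dims, 2] tensor such as {{0, 0}, {1, 1}, {2, 2}, {0, 0}}, so the
  // flat view seen through pad_data is {0, 0, 1, 1, 2, 2, 0, 0};
  // pad_data[2 * i] is the amount prepended to dimension i and
  // pad_data[2 * i + 1] the amount appended.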

  std::vector<int32_t> pad_index;
  for (int i = 0; i < nb_dims; i++) {
    if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
      pad_index.push_back(i);
  }

  // No padding at all; forward the input unchanged.
  if (pad_index.size() == 0) {
    outputs->push_back(inputs.at(0));
    return tensorflow::Status::OK();
  }

  // Only padding on at most 2 axes is supported (GIE-2579).
  if (pad_index.size() > 2)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on more than 2 dimensions");

  // Padding on the batch dimension is not supported.
  if (pad_index[0] == 0)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on batch dimension");

  // Not handling the general case here; padding on both dimension 1 and 3 is
  // rejected for now.
  // TODO(jie): implement pad as uff parser
  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
    return tensorflow::errors::Unimplemented(
        "Padding layer does not support padding on dimension 1 and 3 yet");

  bool legit_pad = true;
  nvinfer1::DimsHW pre_padding(0, 0);
  nvinfer1::DimsHW post_padding(0, 0);

  std::vector<int32_t> permuted_pad_index(pad_index);
  if (pad_index[0] == 1) {
    legit_pad = false;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 2, 1});
    permuted_pad_index[0] = 3;
  }
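  // The permutation {0, 3, 2, 1} applied above swaps dimensions 1 and 3, so a
  // pad requested on the channel dimension (index 1) is re-expressed as a pad
  // on dimension 3, which IPaddingLayer handles as width padding; the inverse
  // transpose below restores the original layout afterwards.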

  for (size_t i = 0; i < pad_index.size(); i++) {
    int index = pad_index[i];
    if (permuted_pad_index[i] == 2) {
      pre_padding.h() = pad_data[index * 2];
      post_padding.h() = pad_data[index * 2 + 1];
    } else if (permuted_pad_index[i] == 3) {
      pre_padding.w() = pad_data[index * 2];
      post_padding.w() = pad_data[index * 2 + 1];
    }
  }

  nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
      *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (!legit_pad)
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});

  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

void Converter::register_op_converters() {
  // Ops used by the vgg_16 slim implementation
  op_registry_["Placeholder"] = ConvertPlaceholder;
  op_registry_["Conv2D"] = ConvertConv2D;
  op_registry_["Relu"] = ConvertActivation;
  op_registry_["MaxPool"] = ConvertPool;
  // This could really be handled by ConvertBinary
  op_registry_["BiasAdd"] = ConvertScale;
  op_registry_["Const"] = ConvertConst;
  // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
  // TODO(ben,jie): this is a temp hack.
  op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
  // op_registry_["AvgPool"] = ConvertPool;

  // Additional ops used by the resnet_50_v1 slim implementation
  op_registry_["Add"] = ConvertBinary;
  op_registry_["Mul"] = ConvertBinary;
  op_registry_["Sub"] = ConvertBinary;
  op_registry_["Rsqrt"] = ConvertUnary;
  op_registry_["Mean"] = ConvertReduce;
  op_registry_["Pad"] = ConvertPad;
  // TODO(ben,jie): Add more ops
}
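
// A minimal sketch of how an additional converter would be wired in (the op
// name and function below are hypothetical, for illustration only): a
// converter follows the signature
//
//   tensorflow::Status ConvertFoo(Converter& ctx,
//                                 const tensorflow::NodeDef& node_def,
//                                 std::vector<TRT_TensorOrWeights> const& inputs,
//                                 std::vector<TRT_TensorOrWeights>* outputs);
//
// and is made visible to convert_node() by adding an entry such as
//   op_registry_["Foo"] = ConvertFoo;
// inside Converter::register_op_converters() above.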

}  // namespace

tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>& input_inds,
    const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
    size_t max_workspace_size_bytes,
    const tensorflow::grappler::GraphProperties& graph_properties,
    tensorflow::NodeDef* trt_node) {
  // Visit nodes in reverse topological order and construct the TRT network.

  // Toposort
  std::vector<tensorflow::Node*> order_vec;
  tensorflow::GetPostOrder(graph, &order_vec);
  // Select just the subgraph
  std::list<tensorflow::Node*> order;
  for (tensorflow::Node* node : order_vec) {
    if (subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the
      // network layer by layer
      order.push_front(node);
    }
  }
  // Topological order is needed to build the TRT network
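  // Why push_front gives that order (illustration): GetPostOrder emits a node
  // only after everything downstream of it, so pushing each emitted node to
  // the front of the list reverses this into producers-before-consumers
  // order.  E.g. for a chain Conv2D -> BiasAdd -> Relu the post-order is
  // {Relu, BiasAdd, Conv2D} and the list built above is
  // {Conv2D, BiasAdd, Relu}, so every node's inputs are already converted by
  // the time it is visited below.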

  tensorflow::tensorrt::Logger trt_logger;

  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
  if (!trt_builder) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT builder object");
  }

  auto trt_network = infer_object(trt_builder->createNetwork());
  if (!trt_network) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT network object");
  }

  // Build the network
  Converter converter(trt_network.get());

  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (std::pair<int, int> const& input : input_inds) {
    int node_id = input.first;
    int output_idx = input.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    auto node_name = node->name();
    input_names.push_back(node_name);  // Insert original node name without port
    // TODO(jie): alternative :)
    if (!graph_properties.HasOutputProperties(node_name))
      return tensorflow::errors::Internal("Failed to find input node: " +
                                          node_name);

    auto op_info_vec = graph_properties.GetOutputProperties(node_name);
    if (static_cast<int>(op_info_vec.size()) <= output_idx)
      return tensorflow::errors::Internal(
          "Accessing output index of: " + std::to_string(output_idx) +
          ", at node: " + node_name + " with output entry from shape_map: " +
          std::to_string(op_info_vec.size()));

    auto op_info = op_info_vec.at(output_idx);

    tensorflow::DataType tf_dtype = op_info.dtype();
    input_dtypes.push_back(tf_dtype);

    nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

    VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
            << ", at node: " << node_name
            << " with output entry from shape_map: "
            << std::to_string(op_info_vec.size());

    // TODO(ben,jie): update TRT input format/dimension
    nvinfer1::DimsCHW input_dim_pseudo_chw;
    for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;

    for (int i = 1; i < op_info.shape().dim_size(); i++) {
      VLOG(2) << "dimension: " << i
              << " , size: " << op_info.shape().dim(i).size();
      input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
    }
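    // For illustration: this simply drops the batch dimension and copies the
    // remaining sizes verbatim, e.g. an inferred TF shape of [8, 224, 224, 3]
    // (values hypothetical) becomes input_dim_pseudo_chw = {224, 224, 3};
    // no NHWC -> NCHW reordering happens here.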

    // TODO(ben,jie): proper way to restore input tensor name?
    auto input_tensor_name = node_name;
    if (output_idx != 0)
      input_tensor_name = node_name + ":" + std::to_string(output_idx);
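    // For illustration: output 0 keeps the bare node name (e.g. "conv1"),
    // while output 1 of the same (hypothetical) node would be addressed as
    // "conv1:1", matching TensorFlow's "node:port" tensor naming.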

    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
        input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);

    if (!input_tensor)
      return tensorflow::errors::InvalidArgument(
          "Failed to create Input layer");
    VLOG(2) << "Input tensor name: " << input_tensor_name;

    if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
      return tensorflow::errors::AlreadyExists(
          "Output tensor already exists for op: " + input_tensor_name);
  }

  VLOG(2) << "Finished sorting";

  for (const tensorflow::Node* node : order) {
    const tensorflow::NodeDef& node_def = node->def();
    VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
    TF_RETURN_IF_ERROR(converter.convert_node(node_def));
  }

  VLOG(2) << "Finished conversion";

  // Gather output metadata
  std::vector<string> output_names;
  std::vector<tensorflow::DataType> output_dtypes;
  for (std::pair<int, int> const& output : output_inds) {
    int node_id = output.first;
    int output_idx = output.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    string op_name = node->name();
    string tensor_name = op_name;
    if (output_idx != 0)
      tensor_name = tensor_name + ":" + std::to_string(output_idx);
    VLOG(2) << "Output tensor name: " << tensor_name;
    output_names.push_back(tensor_name);
    auto tensor_or_weights = converter.get_tensor(tensor_name);
    if (!tensor_or_weights.is_tensor()) {
      return tensorflow::errors::InvalidArgument(
          "Output node is weights, not tensor");
    }
    nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
    if (!tensor) {
      return tensorflow::errors::NotFound("Output tensor not found: " +
                                          tensor_name);
    }
    converter.network()->markOutput(*tensor);
    tensorflow::DataType tf_dtype = node->output_type(output_idx);
    output_dtypes.push_back(tf_dtype);
    nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
    TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
    tensor->setType(trt_dtype);
  }

  VLOG(2) << "Finished output";
  // TODO(jie): static_id is not thread safe.
  static int static_id = 0;

  // Build the engine
  trt_builder->setMaxBatchSize(max_batch_size);
  trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
  VLOG(0) << "Starting build engine " << static_id;
  // TODO(ben,jie): half2 and int8 mode support
  string engine_plan_string;
  {
    auto trt_engine =
        infer_object(trt_builder->buildCudaEngine(*converter.network()));
    VLOG(0) << "Built network";
    auto engine_plan = infer_object(trt_engine->serialize());
    VLOG(0) << "Serialized engine";
    const char* engine_plan_data =
        static_cast<const char*>(engine_plan->data());
    engine_plan_string =
        string(engine_plan_data, engine_plan_data + engine_plan->size());
  }

  VLOG(0) << "Finished engine";

  // Build the TRT op
  // TODO(sami,ben,jie): proper naming!
  tensorflow::NodeDefBuilder op_builder(
      tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
  for (size_t i = 0; i < input_names.size(); ++i) {
    int output_idx = input_inds.at(i).second;
    // The input is already wired up here; it would be redundant to do it again
    // in ConvertSubGraphToTensorRT (convert_graph.cc).
    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
        input_names.at(i), output_idx, input_dtypes.at(i));
    income_edges.push_back(incoming_edge);
  }
  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
      income_edges);
  op_builder.Input(input_list);

  VLOG(0) << "Finished op preparation";

  auto status = op_builder.Attr("serialized_engine", engine_plan_string)
                    .Attr("input_nodes", input_names)
                    .Attr("output_nodes", output_names)
                    .Attr("OutT", output_dtypes)
                    .Finalize(trt_node);
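  // A sketch of the resulting NodeDef (field values are illustrative): a node
  // of op type "TRTEngineOp" named e.g. "my_trt_op0", whose "serialized_engine"
  // attr carries the TensorRT plan built above and whose "input_nodes",
  // "output_nodes" and "OutT" attrs record the subgraph boundary collected
  // earlier.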

  VLOG(0) << status.ToString() << " finished op building";

  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA