convert_nodes.cc revision cd63c718be123324b6c39e0f8fbe453319799746
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <cmath>  // std::sqrt
#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"

// Check that two types are equal, casting to int first so that the failure
// log message is readable.
#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace {

inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
                                       nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case tensorflow::DataType::DT_FLOAT:
      *trt_dtype = nvinfer1::DataType::kFLOAT;
      break;
    case tensorflow::DataType::DT_INT8:
      *trt_dtype = nvinfer1::DataType::kINT8;
      break;
    case tensorflow::DataType::DT_HALF:
      *trt_dtype = nvinfer1::DataType::kHALF;
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported data type");
  }
  return tensorflow::Status::OK();
}
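// ConvertDType maps the subset of TF dtypes that TensorRT accepts here
// (float32, float16, int8) onto their nvinfer1 equivalents and rejects
// everything else with InvalidArgument. A minimal usage sketch:
//
//   nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
//   TF_RETURN_IF_ERROR(ConvertDType(tensorflow::DT_HALF, &trt_type));
//   // trt_type is now nvinfer1::DataType::kHALF.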

inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
  nvinfer1::Dims dims;
  dims.nbDims = tensor.dims();
  for (int i = 0; i < dims.nbDims; i++) {
    dims.d[i] = tensor.dim_size(i);
  }
  return dims;
}

inline int64_t GetShapeSize(nvinfer1::Dims shape) {
  // Returns total number of elements in shape
  int64_t count = 1;
  for (int d = 0; d < shape.nbDims; ++d) {
    count *= shape.d[d];
  }
  return count;
}
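// GetTensorShape copies a tensorflow::Tensor's dimensions into an
// nvinfer1::Dims, and GetShapeSize returns the element count of such a shape
// (the product of its extents); e.g. Dims{nbDims=3, d={2, 3, 4}} gives 24.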

static std::vector<std::pair<int, int>> CreateSamePadding(
    const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
    const std::vector<int64_t>& input_dims) {
  std::vector<std::pair<int, int>> padding(input_dims.size());
  CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?

  for (size_t i = 0; i < input_dims.size(); ++i) {
    // Formula to calculate the padding
    int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
            input_dims[i];
    p = (p > 0) ? p : 0;

    // Right precedence padding, like in TensorFlow
    int left = p / 2;
    int right = p - left;

    VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
            << ", params: " << input_dims[i] << ", " << stride.d[i]
            << ", kernel: " << kernel.d[i];
    padding[i] = {left, right};
  }
  return padding;
}
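// CreateSamePadding mirrors TensorFlow's SAME padding rule: for each spatial
// dimension the total padding is max(0, (ceil(in / stride) - 1) * stride +
// kernel - in), with the extra element going to the right/bottom side.
// Worked example: input 5, stride 1, kernel 3 gives p = 4 + 3 - 5 = 2, i.e.
// padding {1, 1}; input 7, stride 2, kernel 3 gives p = 6 + 3 - 7 = 2 as well.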

class TRT_ShapedWeights {
 public:
  TRT_ShapedWeights(tensorflow::DataType type, const void* values,
                    nvinfer1::Dims shape,
                    const std::vector<char>* owned_values = nullptr)
      : shape_(shape),
        type_(type),
        values_(values),
        owned_values_(owned_values ? *owned_values : std::vector<char>({})),
        empty_weight_flag_(false) {
    // Note: this->shape.type[] is not used
  }

  explicit TRT_ShapedWeights(tensorflow::DataType type)
      : shape_(),
        type_(type),
        values_(nullptr),
        owned_values_(),
        empty_weight_flag_(true) {}

  TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
      : shape_(rhs.shape_),
        type_(rhs.type_),
        values_(rhs.values_),
        owned_values_(rhs.owned_values_),
        empty_weight_flag_(rhs.empty_weight_flag_) {}

  int64_t count() const {
    int64_t c = 1;
    for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    return c;
  }

  nvinfer1::Weights GetWeightsForTRT() const {
    nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(type_, &trt_type));
    if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};

    // Note: this->shape.type[] is not used
    return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
  }

  const void* GetValues() const {
    if (values_) return values_;
    if (owned_values_.size()) return owned_values_.data();
    return nullptr;
  }

  void SetValues(const void* values) {
    values_ = values;
    owned_values_.clear();
  }

  size_t size_bytes() const {
    int type_size = tensorflow::DataTypeSize(this->type_);
    return this->count() * type_size;
  }

  // Default converter
  operator nvinfer1::Weights() const { return GetWeightsForTRT(); }

  nvinfer1::Dims shape_;
  tensorflow::DataType type_;

 private:
  const void* values_;
  std::vector<char> owned_values_;
  bool empty_weight_flag_;
};
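// TRT_ShapedWeights is a small value type describing a constant blob handed
// to TensorRT: a TF dtype, an nvinfer1::Dims shape, and either a borrowed
// pointer (values_) or an owned copy (owned_values_). The implicit conversion
// operator lets it be passed directly wherever an nvinfer1::Weights is
// expected, e.g. as the kernel/bias weights argument of a layer-creation call
// on the network builder (illustrative use; the call sites live in the op
// converters later in this file).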

class TRT_TensorOrWeights {
 public:
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
      : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
      : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
      : tensor_(rhs.tensor_),
        weights_(rhs.weights_),
        variant_(rhs.variant_) {}
  ~TRT_TensorOrWeights() {}

  bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
  bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }

  nvinfer1::ITensor* tensor() {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  const nvinfer1::ITensor* tensor() const {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  TRT_ShapedWeights& weights() {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  const TRT_ShapedWeights& weights() const {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  nvinfer1::Dims shape() const {
    if (is_tensor()) {
      return tensor()->getDimensions();
    } else {
      return weights().shape_;
    }
  }

 private:
  nvinfer1::ITensor* tensor_;
  TRT_ShapedWeights weights_;
  enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
};
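// TRT_TensorOrWeights is a hand-rolled tagged union: every value flowing
// between converted ops is either a live nvinfer1::ITensor produced by a
// TensorRT layer or a constant TRT_ShapedWeights blob. The CHECK_EQ guards in
// the accessors make a mismatched access (e.g. weights() on a tensor node)
// fail fast during conversion instead of silently corrupting the network.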

class TFAttrs {
 public:
  explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    for (const auto& attr : tf_node.attr()) {
      attrs_.insert({attr.first, &attr.second});
    }
  }
  bool count(string key) const { return attrs_.count(key); }
  tensorflow::AttrValue const* at(string key) const {
    if (!attrs_.count(key)) {
      LOG(FATAL) << "Attribute not found: " << key;
    }
    return attrs_.at(key);
  }
  template <typename T>
  T get(string key) const;
  template <typename T>
  T get(string key, const T& default_value) const {
    return attrs_.count(key) ? this->get<T>(key) : default_value;
  }

 private:
  typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
  AttrMap attrs_;
};
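// TFAttrs indexes a NodeDef's attributes by name and exposes typed lookups
// through the get<T>() specializations below. A typical lookup during op
// conversion (attribute names depend on the op being converted):
//
//   TFAttrs attrs(node_def);
//   auto data_format = attrs.get<string>("data_format", "NHWC");
//   auto strides = attrs.get<std::vector<int>>("strides");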

template <>
string TFAttrs::get<string>(string key) const {
  return this->at(key)->s();
}

template <>
std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
  auto attr = this->at(key)->list().i();
  return std::vector<int>(attr.begin(), attr.end());
}

template <>
nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
  auto values = this->get<std::vector<int>>(key);
  nvinfer1::Dims dims;
  dims.nbDims = values.size();
  std::copy(values.begin(), values.end(), dims.d);
  // Note: No dimension type information is included
  return dims;
}

template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
  return trt_dtype;
}

template <>
tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
  return this->at(key)->type();
}

template <typename T>
void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
              nvinfer1::DimsNCHW istrides, T* odata,
              nvinfer1::DimsNCHW ostrides) {
  for (int n = 0; n < shape.n(); ++n) {
    for (int c = 0; c < shape.c(); ++c) {
      for (int h = 0; h < shape.h(); ++h) {
        for (int w = 0; w < shape.w(); ++w) {
          odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
                w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
                                          h * istrides.h() + w * istrides.w()];
        }
      }
    }
  }
}
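// Reorder4 is a generic strided copy over a 4-D extent: for every (n, c, h, w)
// inside `shape` it reads idata at the input strides and writes odata at the
// output strides, so any permutation of a 4-D layout can be expressed purely
// through the choice of the two stride vectors.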

void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
                       TRT_ShapedWeights* oweights) {
  CHECK_EQ(iweights.type_, oweights->type_);
  CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
  int r = iweights.shape_.d[0];
  int s = iweights.shape_.d[1];
  int c = iweights.shape_.d[2];
  int k = iweights.shape_.d[3];
  oweights->shape_.d[0] = k;
  oweights->shape_.d[1] = c;
  oweights->shape_.d[2] = r;
  oweights->shape_.d[3] = s;
  nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
  nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT:
      Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
               istrides,
               static_cast<float*>(const_cast<void*>(oweights->GetValues())),
               ostrides);
      break;
    default:
      LOG(FATAL) << "Unsupported type in weight reordering, expected DT_FLOAT";
  }
}
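// ReorderRSCKToKCRS rewrites a TensorFlow convolution filter from the
// [R, S, C, K] layout (filter height, filter width, input channels, output
// channels, i.e. HWIO) into the [K, C, R, S] (OIHW-style) layout TensorRT
// expects. The loop extent passed to Reorder4 is {k, c, r, s}; istrides are
// the strides of those indices in the contiguous RSCK buffer and ostrides are
// the contiguous strides of the KCRS output.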

struct InferDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

template <typename T>
inline std::shared_ptr<T> infer_object(T* obj) {
  return std::shared_ptr<T>(obj, InferDeleter());
}
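// TensorRT objects are released through their destroy() method rather than
// operator delete, so infer_object wraps them in a shared_ptr whose deleter
// calls destroy(). Sketch (assuming a logger instance is available):
//
//   auto builder = infer_object(nvinfer1::createInferBuilder(logger));
//   // builder->destroy() runs automatically when the last reference drops.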

// Logger for GIE info/warning/errors
class Converter;

using OpConverter =
    std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
                                     std::vector<TRT_TensorOrWeights> const&,
                                     std::vector<TRT_TensorOrWeights>*)>;

class Converter {
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
  std::unordered_map<string, OpConverter> op_registry_;
  nvinfer1::INetworkDefinition* trt_network_;
  std::list<std::vector<uint8_t>> temp_bufs_;

  void register_op_converters();

  std::vector<TRT_TensorOrWeights> get_inputs(
      const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs;
    for (const auto& input_name : node_def.input()) {
      VLOG(2) << "Retrieve input: " << input_name;
      inputs.push_back(trt_tensors_.at(input_name));
    }
    return inputs;
  }

 public:
  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
      : trt_network_(trt_network) {
    this->register_op_converters();
  }

  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                     nvinfer1::Dims shape) {
    TRT_ShapedWeights weights(type, nullptr, shape);
    // TODO(jie): check weights size_bytes. 0 means type error
    temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    weights.SetValues(temp_bufs_.back().data());
    return weights;
  }

  TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    return this->get_temp_weights(weights.type_, weights.shape_);
  }

  tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    string op = node_def.op();
    if (!op_registry_.count(op)) {
      return tensorflow::errors::Unimplemented(
          "No converter registered for op: " + op);
    }
    OpConverter op_converter = op_registry_.at(op);
    std::vector<TRT_TensorOrWeights> outputs;
    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    for (size_t i = 0; i < outputs.size(); ++i) {
      TRT_TensorOrWeights output = outputs.at(i);
      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
      string output_name = node_def.name();
      if (i != 0) output_name = output_name + ":" + std::to_string(i);
      if (output.is_tensor()) {
        output.tensor()->setName(output_name.c_str());
      }
      VLOG(2) << "Write out tensor: " << output_name;
      if (!trt_tensors_.insert({output_name, output}).second) {
        return tensorflow::errors::AlreadyExists(
            "Output tensor already exists for op: " + op);
      }
    }
    return tensorflow::Status::OK();
  }

  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  TRT_TensorOrWeights get_tensor(string name) {
    if (!trt_tensors_.count(name)) {
      return TRT_TensorOrWeights(nullptr);
    }
    return trt_tensors_.at(name);
  }

  bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
  }

  nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
                                     std::vector<int> order) {
    auto dims = input_tensor->getDimensions();

    // TODO(jie): change the return to status and properly exit
    if (order.size() - 1 != size_t(dims.nbDims))
      LOG(ERROR) << "Dimension does not match, fail gracefully";

    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    nvinfer1::Permutation permutation;
    for (int32_t i = 0; i < dims.nbDims; ++i) {
      permutation.order[i] = order[i + 1] - 1;
    }
    layer->setFirstTranspose(permutation);

    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = dims.nbDims;
    for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
      reshape_dims.d[i] = 0;
      reshape_dims.type[i] = dims.type[i];
    }
    layer->setReshapeDimensions(reshape_dims);
    return layer->getOutput(0);
  }
};
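// Converter drives the per-node translation. trt_tensors_ maps TF tensor
// names to the TRT_TensorOrWeights already emitted, op_registry_ maps TF op
// names to OpConverter callbacks (populated by register_op_converters()), and
// temp_bufs_ keeps temporary weight buffers alive for the duration of the
// conversion. convert_node() gathers the inputs by name, dispatches to the
// registered converter, then records each output under "<node name>" (or
// "<node name>:i" for i > 0) so later nodes can find it.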
455825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
4566908cc233c679b8fe61d99a30d3828362caf47beSami Kama// ****************************************************************************
457e01844e65e0dbd2682a894946bec7f072d36fa27Jie// Constant folding functions
458e01844e65e0dbd2682a894946bec7f072d36fa27Jie// TODO(jie): once optimizer kicks in, we should have done constant folding
459e01844e65e0dbd2682a894946bec7f072d36fa27Jie// there.
4606908cc233c679b8fe61d99a30d3828362caf47beSami Kama//*****************************************************************************/
struct LambdaFactory {
  enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
  OP_CATEGORY op;

  template <typename T>
  std::function<T(T)> unary() {
    switch (op) {
      case OP_CATEGORY::RSQRT: {
        VLOG(2) << "RSQRT GETS DONE";
        return [](T t) -> T { return 1.0 / std::sqrt(t); };
      }
      case OP_CATEGORY::NEG:
        return [](T t) -> T { return -t; };
      default:
        VLOG(2) << "Unsupported op for unary: " << static_cast<int>(op);
        return nullptr;
    }
  }

  template <typename T>
  std::function<T(T, T)> binary() {
    switch (op) {
      case OP_CATEGORY::ADD:
        return [](T l, T r) -> T { return l + r; };
      case OP_CATEGORY::SUB:
        return [](T l, T r) -> T { return l - r; };
      case OP_CATEGORY::MUL:
        return [](T l, T r) -> T { return l * r; };
      default:
        LOG(WARNING) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [](T l, T r) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_r(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l + val;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l - val;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l * val;
        };
      default:
        LOG(WARNING) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_l(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val + l;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val - l;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val * l;
        };
      default:
        LOG(ERROR) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }
};
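// LambdaFactory turns an op category into a callable so the folding helpers
// below can stay generic: unary() and binary() return elementwise functors,
// while broadcast_l()/broadcast_r() bind a scalar operand on the left or the
// right. A small sketch of how UnaryCompute uses it:
//
//   LambdaFactory op;
//   op.op = LambdaFactory::OP_CATEGORY::RSQRT;
//   auto f = op.unary<float>();  // f(4.0f) == 0.5f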

tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
                                TRT_ShapedWeights* oweights,
                                LambdaFactory unary_op) {
  CHECK_EQ(iweights.type_, oweights->type_);
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp = static_cast<float const*>(iweights.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
      std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights.type_));
  }
  return tensorflow::Status::OK();
}

tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
                                 const TRT_ShapedWeights& iweights_r,
                                 TRT_ShapedWeights* oweights,
                                 LambdaFactory binary_op) {
  // Assume iweights_l.type == iweights_r.type
  CHECK_EQ(iweights_l.type_, oweights->type_);
  CHECK_EQ(iweights_r.type_, oweights->type_);
  VLOG(2) << "SANITY CHECK!";

  switch (iweights_l.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp_l = static_cast<const float*>(iweights_l.GetValues());
      auto inp_r = static_cast<const float*>(iweights_r.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));

      if (iweights_l.count() != iweights_r.count()) {
        // Only scalar (rank-zero) broadcast is supported.
        if (iweights_l.count() == 1) {
          VLOG(2) << "Broadcasting left scalar: " << (*inp_l);
          std::transform(inp_r, inp_r + iweights_r.count(), oup,
                         binary_op.broadcast_l<float>(*inp_l));
        } else if (iweights_r.count() == 1) {
          VLOG(2) << "Broadcasting right scalar: " << (*inp_r);
          std::transform(inp_l, inp_l + iweights_l.count(), oup,
                         binary_op.broadcast_r<float>(*inp_r));
        } else {
          return tensorflow::errors::Unimplemented(
              "Binary op with non-rankZero broadcast not supported");
        }
      } else {
        std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
                       binary_op.binary<float>());
      }
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights_l.type_));
  }

  return tensorflow::Status::OK();
}
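// UnaryCompute and BinaryCompute apply a LambdaFactory functor over host-side
// weight buffers (currently DT_FLOAT only). BinaryCompute handles three
// cases: equal element counts (elementwise), a scalar on the left
// (broadcast_l), or a scalar on the right (broadcast_r); anything else is
// reported as unimplemented. The ConstantFold* converters below wrap these
// helpers as OpConverter-compatible entry points.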

tensorflow::Status ConstantFoldUnary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input = inputs.at(0).weights();

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);

  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later
  // Check type consistency
  CHECK_EQ(weights_input.type_,
           TFAttrs(node_def).get<tensorflow::DataType>("T"));

  // Maybe I should do a switch
  LambdaFactory unary_op;
  if (node_def.op() == "Rsqrt") {
    // Compute rsqrt
    unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    // Pass the output
    if (ret == tensorflow::Status::OK()) {
      outputs->push_back(TRT_TensorOrWeights(weights_output));
    }
    return ret;
  } else {
    return tensorflow::errors::Unimplemented("Unary op not supported: " +
                                             node_def.op());
  }
}
651825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
652825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// TODO(jie,ben) broadcast is needed yet not implemented
653825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// Let's get the simple stuff working first. Maybe we should fall bakc to TF
654825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama//   approach for constant folding
655825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConstantFoldBinary(
656bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
657825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights> const& inputs,
658825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
659825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
660825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
661825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
662bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Check type consistency
663825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
664825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
665825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
666825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
667825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Binary op implicit broadcast not supported: " + node_def.op());
668825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
669825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): constant fold should really fall back to TF.
670bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  int nb_dims = weights_input_l.shape_.nbDims;
671825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::Dims output_shape;
672bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  output_shape.nbDims = nb_dims;
673bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  VLOG(2) << "nb_dims: " << nb_dims
674bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney          << ", the other: " << weights_input_r.shape_.nbDims;
675bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  for (int i = 0; i < nb_dims; i++) {
676825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
677825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      output_shape.d[i] = weights_input_l.shape_.d[i];
678825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    } else if (weights_input_l.shape_.d[i] == 1 ||
679825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama               weights_input_r.shape_.d[i] == 1) {
680825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      output_shape.d[i] =
681825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama          std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
682825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    } else {
683825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      return tensorflow::errors::Unimplemented(
684825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama          "Binary op with incompatible shape, at " + node_def.op());
685825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    }
686f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "left: " << weights_input_l.shape_.d[i]
687f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", right: " << weights_input_r.shape_.d[i]
688f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", output: " << output_shape.d[i];
689825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
690825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
691825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // FIXME assume type matches input weights
692bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Get trt type & shape
693825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
694bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Maybe this part has to be moved into the block of rsqrt later
695825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
696825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
697bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Allocate output weights
698825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
699825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
700825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // Maybe I should do a switch
701825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  LambdaFactory binary_op;
702825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (node_def.op() == "Sub") {
703825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
704825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else if (node_def.op() == "Mul") {
705825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
706825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else if (node_def.op() == "Add") {
707825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
708825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
709825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented("Binary op not supported: " +
710825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                             node_def.op());
711825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
712825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
713825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                           binary_op);
714825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
715bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Pass the output
716825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (ret == tensorflow::Status::OK()) {
717825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    outputs->push_back(TRT_TensorOrWeights(weights_output));
718825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
719825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
720825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return ret;
721825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
722825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
723bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney// TODO(jie): broadcast is needed yet not implemented.
724bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney// Only implemented channel-wise for the time being
725825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status BinaryTensorOpWeight(
726bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
727825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
728825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
729825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // FIXME assume type matches input weights
730bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Get trt type & shape
731bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Maybe this part has to be moved into the block of rsqrt later
732825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
733bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Check type consistency
734825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
735bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
736825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DataType ttype;
737bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
738bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error message
739825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
740bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Check scale mode
741825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto dims_w = weights.shape_;
742825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto dims_t = tensor->getDimensions();
743825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
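  // Pick the TRT scale mode from the weight shape. A single value maps to
  // kUNIFORM; otherwise the weight dims (which still carry a leading batch
  // dimension, unlike the tensor dims) are compared against the tensor dims
  // below to choose between kCHANNEL and kELEMENTWISE.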
744bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Default to element-wise
745825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
746825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
747825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (weights.count() == 1) {
748f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "UNIFORM";
749825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
750825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
751bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    // No broadcasting on Batch dimension;
752e01844e65e0dbd2682a894946bec7f072d36fa27Jie    assert(dims_w.d[0] == 1);
753825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
754bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    // Broadcasting on Channel dimension only allowed in kUNIFORM
755e01844e65e0dbd2682a894946bec7f072d36fa27Jie    assert(dims_w.d[1] == dims_t.d[0]);
756e01844e65e0dbd2682a894946bec7f072d36fa27Jie    assert(dims_w.nbDims == dims_t.nbDims);
757825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
758bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    // Default is element-wise;
759e01844e65e0dbd2682a894946bec7f072d36fa27Jie    for (int i = 2; i < dims_w.nbDims; i++) {
760e01844e65e0dbd2682a894946bec7f072d36fa27Jie      if (dims_w.d[i] != dims_t.d[i - 1]) {
761825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
762825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        break;
763825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      }
764825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    }
765825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
767e01844e65e0dbd2682a894946bec7f072d36fa27Jie      for (int i = 2; i < dims_w.nbDims; i++) {
768e01844e65e0dbd2682a894946bec7f072d36fa27Jie        if (dims_w.d[i] != 1)
769825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama          return tensorflow::errors::InvalidArgument(
770e01844e65e0dbd2682a894946bec7f072d36fa27Jie              "Weight shape not compatible, at " + node_def.name());
771825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      }
772825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    }
773825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
774825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
775bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Prepare weights
776bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  TRT_ShapedWeights shift_weights(weights.type_);
777bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  TRT_ShapedWeights scale_weights(weights.type_);
778bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  TRT_ShapedWeights power_weights(weights.type_);
779825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
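  // Map the TF op onto TensorRT's scale layer, which computes
  // output = (input * scale + shift) ^ power per the selected scale mode:
  // Sub negates the weights and uses them as shift, Mul uses them as scale,
  // and Add uses them as shift; power is left empty (TRT then treats it as 1).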
780825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // Maybe I should do a switch
781825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (node_def.op() == "Sub") {
782825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
783825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    LambdaFactory unary_op;
784825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
785bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
786bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    shift_weights = neg_weights;
787825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else if (node_def.op() == "Mul") {
788bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    scale_weights = weights;
789825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else if (node_def.op() == "Add") {
790bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    shift_weights = weights;
791825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
792825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented("Binary op not supported: " +
793825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                             node_def.op());
794825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
795825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
796825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
797bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
798bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      scale_weights, power_weights);
799825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
800825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
801825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
802bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Pass the output
803825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
804825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
805825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
806825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
807825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status BinaryTensorOpTensor(
808bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
809825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
810825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
811bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
812bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      {"Add", nvinfer1::ElementWiseOperation::kSUM},
813bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
814bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      // {"max", nvinfer1::ElementWiseOperation::kMAX},
815bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      // {"min", nvinfer1::ElementWiseOperation::kMIN},
816bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
817bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      {"Div", nvinfer1::ElementWiseOperation::kDIV},
818bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  };
819825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
820825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // FIXME assume type matches input weights
821bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Get trt type & shape
822825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
823bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Maybe this part has to be moved into the block of rsqrt later
824825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
825825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
826bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Check type consistency
827825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
828825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
829825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto op_pair = ops.find(node_def.op());
830825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (op_pair == ops.end())
831d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559gracehoney    return tensorflow::errors::Unimplemented(
832d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559gracehoney        "binary op: " + node_def.op() +
833d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559gracehoney        " not supported at: " + node_def.name());
834825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
835825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
836825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      *const_cast<nvinfer1::ITensor*>(tensor_l),
837825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
838825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
839825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
840825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
841bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Pass the output
842825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
843825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
844825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
845825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
846825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertPlaceholder(
847bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
848825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights> const& inputs,
849825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
850f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama  VLOG(2) << "Placeholder should have been replaced already";
851bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  return tensorflow::errors::Unimplemented("Cannot convert Placeholder op");
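  // Note: the code below the early return above is unreachable; it sketches
  // how a Placeholder would otherwise map to a TRT network input (dtype and
  // shape taken from the attributes, with the batch dimension stripped).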
852825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // OK, this makes sense since we are supposed to replace it with an input
853825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
854825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
855825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
856825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
857825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  dims.nbDims--;
858825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
859825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
860825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output =
861825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
862825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (!output) {
863825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
864825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
865825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output));
866825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
867825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
868825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
869825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertConv2D(Converter& ctx,
870bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                 const tensorflow::NodeDef& node_def,
871bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                 const std::vector<TRT_TensorOrWeights>& inputs,
872825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 std::vector<TRT_TensorOrWeights>* outputs) {
873825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
874825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): handle NHWC/NCHW transpose;
875825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
876825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
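  // TF stores Conv2D kernels as RSCK (height, width, input channels, output
  // channels); TensorRT expects KCRS, so reorder into a temporary buffer.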
877bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  ReorderRSCKToKCRS(weights_rsck, &weights);
878825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights biases(weights.type_);
879825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int noutput = weights.shape_.d[0];
880825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DimsHW kernel_size;
881825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  kernel_size.h() = weights.shape_.d[2];
882825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  kernel_size.w() = weights.shape_.d[3];
883825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
884825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
885825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int h_index = 2;
886825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int w_index = 3;
887bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  auto data_format = attrs.get<string>("data_format");
888825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
889bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
890825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 {0, 3, 1, 2});
891825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    h_index = 1;
892825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    w_index = 2;
893825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): transpose it
894825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
895e01844e65e0dbd2682a894946bec7f072d36fa27Jie
896825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): stride. (NHWC/NCHW)
897825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto tf_stride = attrs.get<std::vector<int>>("strides");
898825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
899825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
900825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto tensor_dim = tensor->getDimensions();
901825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  std::vector<std::pair<int, int>> padding;
902825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): padding.
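  // "SAME" padding follows the TF rule: output size = ceil(input / stride),
  // so the total pad per spatial dim is max((out - 1) * stride + kernel - in, 0),
  // split across both sides (e.g. a 3x3 kernel at stride 1 needs one pixel on
  // each side). CreateSamePadding is assumed here to implement exactly that.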
903bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  if (attrs.get<string>("padding") == "SAME") {
904825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // This is NCHW tensor with no batch dimension.
905825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    //  1 -> h
906825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    //  2 -> w
907bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    padding = CreateSamePadding(
908e01844e65e0dbd2682a894946bec7f072d36fa27Jie        stride, kernel_size,
909e01844e65e0dbd2682a894946bec7f072d36fa27Jie        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
910825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
911825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    padding = {{0, 0}, {0, 0}};
912825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
913825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
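  // TRT convolution padding must be symmetric, so when SAME padding comes out
  // asymmetric it is materialized with an explicit padding layer in front of
  // the convolution and the convolution's own padding is reset to zero.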
914825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (padding[0].first != padding[0].second ||
915825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      padding[1].first != padding[1].second) {
916825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): handle asymmetric padding
917bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
918f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", " << padding[1].first << ", " << padding[1].second;
91924e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie
92024e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie    auto dim_before = tensor->getDimensions();
921f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
922f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", " << dim_before.d[2] << ", " << dim_before.d[3];
923bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    auto pad_layer = ctx.network()->addPadding(
924825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        *const_cast<nvinfer1::ITensor*>(tensor),
92524e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie        nvinfer1::DimsHW(padding[0].first, padding[1].first),
92624e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie        nvinfer1::DimsHW(padding[0].second, padding[1].second));
92724e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie    padding = {{0, 0}, {0, 0}};
928bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = pad_layer->getOutput(0);
92924e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie    auto dim_after = tensor->getDimensions();
930f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
931f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", " << dim_after.d[2] << ", " << dim_after.d[3];
932825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
933825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
934825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IConvolutionLayer* layer =
935825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
936825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                    noutput, kernel_size, weights, biases);
937825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
938825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setStride(stride);
939825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setPadding({padding[0].first, padding[1].first});
940825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setName(node_def.name().c_str());
941825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
942825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
94324e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie  auto dim_after = output_tensor->getDimensions();
944f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama  VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
945f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama          << ", " << dim_after.d[2] << ", " << dim_after.d[3];
94624e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie
947825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
948825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): transpose it back!
949bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
950825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
951f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "NCHW !!!!";
952825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
953825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
954825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
955825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
956825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
957825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertPool(Converter& ctx,
958bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                               const tensorflow::NodeDef& node_def,
959825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                               std::vector<TRT_TensorOrWeights> const& inputs,
960825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                               std::vector<TRT_TensorOrWeights>* outputs) {
961825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
962825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
963825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
964825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int h_index = 2;
965825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int w_index = 3;
966bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  auto data_format = attrs.get<string>("data_format");
967825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
968825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    h_index = 1;
969825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    w_index = 2;
970bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
971825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 {0, 3, 1, 2});
972825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
973f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "NCHW !!!!";
974825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
975825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::PoolingType type;
976825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): support other pooling type
977825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (node_def.op() == "MaxPool")
978825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    type = nvinfer1::PoolingType::kMAX;
979825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  else
980bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    return tensorflow::errors::Unimplemented("Only supports Max pool");
981825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
982825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): NCHW
983825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto tf_stride = attrs.get<std::vector<int>>("strides");
984825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
985825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
986825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto tf_kernel = attrs.get<std::vector<int>>("ksize");
987825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
988825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
989825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto tensor_dim = tensor->getDimensions();
990825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  std::vector<std::pair<int, int>> padding;
991825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): padding.
992bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  if (attrs.get<string>("padding") == "SAME") {
993825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // This is NCHW tensor with no batch dimension.
994825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    //  1 -> h
995825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    //  2 -> w
996bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    padding = CreateSamePadding(
997825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        stride, ksize,
998825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
999bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  } else if (attrs.get<string>("padding") == "VALID") {
1000825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // No padding for valid padding here
1001bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    VLOG(2) << "No padding added for VALID padding in pool " << node_def.name();
1002825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    padding = {{0, 0}, {0, 0}};
1003825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1004825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1005825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Current MaxPool only supports SAME and VALID padding");
1006825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1007825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
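  // Same trick as in ConvertConv2D: TRT pooling padding must be symmetric, so
  // asymmetric padding is applied with an explicit padding layer first.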
1008825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (padding[0].first != padding[0].second ||
1009825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      padding[1].first != padding[1].second) {
1010825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): handle asymmetric padding
1011bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
1012f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama            << ", " << padding[1].first << ", " << padding[1].second;
1013bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    auto pad_layer = ctx.network()->addPadding(
1014825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        *const_cast<nvinfer1::ITensor*>(tensor),
101524e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie        nvinfer1::DimsHW(padding[0].first, padding[1].first),
101624e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie        nvinfer1::DimsHW(padding[0].second, padding[1].second));
101724e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie    padding = {{0, 0}, {0, 0}};
1018bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = pad_layer->getOutput(0);
1019825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1020825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1021825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
1022825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
1023825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1024825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setStride(stride);
1025825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setPadding({padding[0].first, padding[1].first});
1026825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  layer->setName(node_def.name().c_str());
1027825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
1028825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1029825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
1030825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): transpose it back!
1031bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
1032825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1033f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "NCHW !!!!";
1034825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1035825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
1036825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1037825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1038825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1039825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertActivation(
1040bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
1041825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights> const& inputs,
1042825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
1043825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
1044825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
1045825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
1046825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
1047825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
1048825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1049825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1050825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1051825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertScale(Converter& ctx,
1052bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                const tensorflow::NodeDef& node_def,
1053825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights> const& inputs,
1054825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights>* outputs) {
1055825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
1056825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      !inputs.at(1).is_weights())
1057825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1058bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney        "Only supports tensor op weight for now, at " + node_def.name());
1059bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Implement tensor binaryOp weight [channel wise] for now;
1060825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
1061825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1062825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // TODO(jie): handle NHWC/NCHW transpose;
1063825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights = inputs.at(1).weights();
1064825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights empty_weights(weights.type_);
1065825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1066825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
1067825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1068bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Transpose NHWC
1069bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  auto data_format = attrs.get<string>("data_format");
1070825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
1071bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
1072825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 {0, 3, 1, 2});
1073825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): transpose it
1074825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1075f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "NCHW !!!!";
1076825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
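  // The weights are fed to the scale layer as the per-channel shift (bias);
  // scale and power are left empty, so each channel of the input simply gets
  // the matching weight value added to it.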
1077825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
1078825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
1079825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      weights, empty_weights, empty_weights);
1080825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1081825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
1082825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (data_format == "NHWC") {
1083825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    // TODO(jie): transpose it back!
1084bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
1085825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1086f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "NCHW !!!!";
1087825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1088825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(output_tensor));
1089825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1090825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1091825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1092825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertConst(Converter& ctx,
1093bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                const tensorflow::NodeDef& node_def,
1094825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights> const& inputs,
1095825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights>* outputs) {
1096bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  const auto& weights_tensor = node_def.attr().at("value").tensor();
1097825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1098bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Get trt type & shape
1099825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
1100cfa374cefe132be886c26a374c51454177c68868gracehoney  const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
1101825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1102bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Create shaped weights as output
1103825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  tensorflow::Tensor tensor;
1104825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (!tensor.FromProto(weights_tensor))
1105bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
1106825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                        node_def.name());
1107825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1108825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights weights(dtype);
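  // A Const node stores its value either as an explicit float_val list
  // (typically scalars and small tensors) or as a packed tensor_content byte
  // string; both forms are handled below.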
1109825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (!weights_tensor.float_val().empty()) {
1110f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "SCALAR!!! " << node_def.name();
1111825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    nvinfer1::Dims scalar_shape;
1112825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    if (tensor.dims() > 0) {
1113bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      VLOG(2) << "Dimensions: " << tensor.dims();
1114825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
1115bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                  GetTensorShape(tensor));
1116825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    } else {
1117bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney      VLOG(2) << "Dimensions: " << tensor.dims();
1118825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      scalar_shape.nbDims = 1;
1119825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      scalar_shape.d[0] = 1;
1120825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
1121825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
1122825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        scalar_shape.d[i] = 0;
1123825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
1124825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      }
1125825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
1126825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                  scalar_shape);
1127825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    }
1128825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else if (!weights_tensor.tensor_content().empty()) {
1129f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama    VLOG(2) << "TENSOR!!! " << node_def.name();
1130bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    const auto& content = weights_tensor.tensor_content();
1131cfa374cefe132be886c26a374c51454177c68868gracehoney
1132cfa374cefe132be886c26a374c51454177c68868gracehoney    std::vector<char> values;
1133cfa374cefe132be886c26a374c51454177c68868gracehoney    if (content.size() > 0) {
1134cfa374cefe132be886c26a374c51454177c68868gracehoney      const int dtype_size = tensorflow::DataTypeSize(dtype);
1135cfa374cefe132be886c26a374c51454177c68868gracehoney      CHECK_EQ(0, content.size() % dtype_size)
1136cfa374cefe132be886c26a374c51454177c68868gracehoney          << "Tensor content size (" << content.size()
1137cfa374cefe132be886c26a374c51454177c68868gracehoney          << ") is not a multiple of " << dtype_size;
1138cfa374cefe132be886c26a374c51454177c68868gracehoney      values.resize(content.size());
1139cfa374cefe132be886c26a374c51454177c68868gracehoney      port::CopyToArray(content, values.data());
1140cfa374cefe132be886c26a374c51454177c68868gracehoney    }
1141cfa374cefe132be886c26a374c51454177c68868gracehoney    weights =
1142cfa374cefe132be886c26a374c51454177c68868gracehoney        TRT_ShapedWeights(dtype, nullptr, GetTensorShape(tensor), &values);
1143825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1144825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1145bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney        "Unsupported constant type, at " + node_def.name());
1146825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1147bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Pass the output
1148825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(TRT_TensorOrWeights(weights));
1149825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1150825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1151825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1152825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertIdentity(
1153bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    Converter& ctx, const tensorflow::NodeDef& node_def,
1154825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights> const& inputs,
1155825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    std::vector<TRT_TensorOrWeights>* outputs) {
1156825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  outputs->push_back(inputs.at(0));
1157825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1158825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1159825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1160825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertBinary(Converter& ctx,
1161bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                 const tensorflow::NodeDef& node_def,
1162825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 std::vector<TRT_TensorOrWeights> const& inputs,
1163825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 std::vector<TRT_TensorOrWeights>* outputs) {
1164825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.size() != 2)
1165825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::FailedPrecondition(
1166825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Binary ops require two inputs, at " + node_def.name());
1167825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
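  // Dispatch on the operand kinds: two constants are folded on the host,
  // tensor + weights (in either order) goes through the scale-layer path, and
  // tensor + tensor becomes a TRT elementwise layer.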
1168825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
1169825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return ConstantFoldBinary(ctx, node_def, inputs, outputs);
1170825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1171825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
1172825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
1173825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                inputs.at(1).weights(), outputs);
1174825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1175825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
1176825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
1177825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                inputs.at(0).weights(), outputs);
1178825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1179825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
1180825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
1181825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                inputs.at(1).tensor(), outputs);
1182825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1183825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::errors::Unknown("Binary op input error, at " +
1184825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                     node_def.name());
1185825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1186825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1187825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertUnary(Converter& ctx,
1188bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                const tensorflow::NodeDef& node_def,
1189825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights> const& inputs,
1190825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                std::vector<TRT_TensorOrWeights>* outputs) {
1191825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.size() != 1)
1192825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::FailedPrecondition(
1193825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Unary ops require a single input, at " + node_def.name());
1194825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1195825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.at(0).is_weights())
1196825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return ConstantFoldUnary(ctx, node_def, inputs, outputs);
1197825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  else if (inputs.at(0).is_tensor())
1198825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1199825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Unary op for tensor not supported, at " + node_def.name());
1200825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1201825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::errors::Unknown("Unary op input error, at " +
1202825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                     node_def.name());
1203825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1204825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1205825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertReduce(Converter& ctx,
1206bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney                                 const tensorflow::NodeDef& node_def,
1207825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 std::vector<TRT_TensorOrWeights> const& inputs,
1208825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 std::vector<TRT_TensorOrWeights>* outputs) {
1209825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
1210825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      !inputs.at(1).is_weights())
1211825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument(
1212825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Input expects tensor and weights, at " + node_def.name());
1213825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1214bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Implement tensor binaryOp weight [channel wise] for now;
1215825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
1216825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto dims = tensor->getDimensions();
1217bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Restore implicit batch dimension
1218bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  int nb_dims = dims.nbDims + 1;
1219825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1220825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TRT_ShapedWeights index_list = inputs.at(1).weights();
1221825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1222825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  TFAttrs attrs(node_def);
1223bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // TODO(jie): handle data type.
1224bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Index type here is done through TF type, so I can leverage their
1225bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // EnumToDataType for my cast
1226825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto index_type = attrs.get<tensorflow::DataType>("Tidx");
1227825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1228825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  // Only expect to handle INT32 as attributes for now
1229825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (index_type != tensorflow::DataType::DT_INT32)
1230825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
1231825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  auto index_list_data =
1232cfa374cefe132be886c26a374c51454177c68868gracehoney      static_cast<int*>(const_cast<void*>(index_list.GetValues()));
1233825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1234bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Hack warning: have to fall back to pool layer since reduce is not in public
1235bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // TRT yet.
1236bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  if (nb_dims != 4)
1237825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument(
1238825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "TRT only supports reduce on 4-dimensional tensors, at " +
1239825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        node_def.name());
1240825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (index_list.count() > 2)
1241825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::InvalidArgument(
1242825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "TRT cannot support reduce on more than 2 dimensions, at " +
1243825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        node_def.name());
1244825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
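  // The pooling hack can only reduce over the spatial slots, so a requested
  // reduction over axis 1 (channels) is handled by transposing axis 1 with one
  // of the reduced spatial axes, pooling, and transposing back afterwards.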
1245825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  std::set<int> idx_set;
1246bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // We cannot reduce over the channel dimension; permuted_index flags a transpose
1247825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  int permuted_index = -1;
1248825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  for (int i = 0; i < index_list.count(); i++) {
1249825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    if (index_list_data[i] == 0)
1250825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      return tensorflow::errors::InvalidArgument("TRT cannot reduce over the batch dimension, at " +
1251825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                                 node_def.name());
1252825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    if (index_list_data[i] == 1) permuted_index = 1;
1253825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    idx_set.emplace(index_list_data[i]);
1254825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1255825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1256bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  std::vector<int> permutation_order(nb_dims);
1257825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::DimsHW pool_kernel;
1258825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (permuted_index == 1) {
1259bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    for (int i = 2; i < nb_dims; i++) {
1260825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      if (idx_set.count(i)) {
1261825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        permuted_index = i;
1262825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        break;
1263825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama      }
1264825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    }
1265bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;
1266825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1267825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    permutation_order[permuted_index] = 1;
1268825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    permutation_order[1] = permuted_index;
1269825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1270bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    // Apply permutation before extracting dimension for pool_kernel
1271bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
1272825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                 permutation_order);
1273825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1274825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1275bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney  // Build the pooling kernel over the spatial dimensions being reduced
1276825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
1277825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
1278825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1279825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  nvinfer1::ITensor* output_tensor;
1280825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
1281825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (node_def.op() == "Mean") {
1282825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    nvinfer1::IPoolingLayer* layer =
1283825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
1284825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama                                  nvinfer1::PoolingType::kAVERAGE, pool_kernel);
1285825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    output_tensor = layer->getOutput(0);
1286825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  } else {
1287825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama    return tensorflow::errors::Unimplemented(
1288825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        "Op not supported " + node_def.op() + ", at " + node_def.name());
1289825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
1290825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  if (permuted_index != -1) {
1291bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    // Undo the earlier permutation on the output tensor
1292bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney    output_tensor = ctx.TransposeTensor(
1293825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
1294825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  }
  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
1295825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama  return tensorflow::Status::OK();
1296825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}
1297825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama
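// Converts a TF Pad op into a TensorRT IPaddingLayer. TensorRT can only pad
// the two spatial (H/W) dimensions, so padding on the batch dimension is
// rejected and a padded dimension 1 is first transposed into a spatial slot.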
tensorflow::Status ConvertPad(Converter& ctx,
                              const tensorflow::NodeDef& node_def,
                              std::vector<TRT_TensorOrWeights> const& inputs,
                              std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  // Pad takes a tensor input and a constant paddings weight.
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights pads = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // The padding index type is stored as a TF DataType so that
  // EnumToDataType can be reused for the cast.
  auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
  // TODO(jie): handle data type conversion for TRT?

  if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
    return tensorflow::errors::InvalidArgument(
        "Pad only supports explicit padding on 4 dimensional tensor, at " +
        node_def.name());

  // Only expect to handle INT32 as attributes for now
  if (padding_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented(
        "Tpaddings supports only DT_INT32");
  auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));

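  // Collect the dimensions that receive non-zero padding.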
  std::vector<int32_t> pad_index;
  for (int i = 0; i < nb_dims; i++) {
    if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
      pad_index.push_back(i);
  }

  // No padding requested at all; pass the input through unchanged.
  if (pad_index.size() == 0) {
    outputs->push_back(inputs.at(0));
    return tensorflow::Status::OK();
  }

  // The TensorRT padding layer supports padding on at most two axes (GIE-2579).
  if (pad_index.size() > 2)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on more than 2 dimensions");

  // Padding on batch dimension is not supported
  if (pad_index[0] == 0)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on batch dimension");

  // Not doing the legit thing here: padding on both dimension 1 and
  // dimension 3 at the same time is not handled.
  // TODO(jie): implement pad as uff parser
  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
    return tensorflow::errors::Unimplemented(
        "Padding layer does not support padding on dimension 1 and 3 yet");

  bool legit_pad = true;
  nvinfer1::DimsHW pre_padding(0, 0);
  nvinfer1::DimsHW post_padding(0, 0);

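  // If dimension 1 is the padded axis, swap dimensions 1 and 3 so the padding
  // lands on a spatial slot that the TensorRT padding layer can handle.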
  std::vector<int32_t> permuted_pad_index(pad_index);
  if (pad_index[0] == 1) {
    legit_pad = false;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 2, 1});
    permuted_pad_index[0] = 3;
  }

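  // Map the requested pre/post padding amounts onto TensorRT's H and W
  // padding descriptors.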
  for (size_t i = 0; i < pad_index.size(); i++) {
    int index = pad_index[i];
    if (permuted_pad_index[i] == 2) {
      pre_padding.h() = pad_data[index * 2];
      post_padding.h() = pad_data[index * 2 + 1];
    } else if (permuted_pad_index[i] == 3) {
      pre_padding.w() = pad_data[index * 2];
      post_padding.w() = pad_data[index * 2 + 1];
    }
  }

  nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
      *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

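  // If we transposed earlier, undo the swap to restore the original layout.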
  if (!legit_pad)
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});

  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

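// Populates op_registry_ with the TF-op-name -> converter-function mapping
// used when converting the nodes of a subgraph.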
void Converter::register_op_converters() {
  // vgg_16 slim implementation
  op_registry_["Placeholder"] = ConvertPlaceholder;
  op_registry_["Conv2D"] = ConvertConv2D;
  op_registry_["Relu"] = ConvertActivation;
  op_registry_["MaxPool"] = ConvertPool;
  // This could really be handled as ConvertBinary.
  op_registry_["BiasAdd"] = ConvertScale;
  op_registry_["Const"] = ConvertConst;
  // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
  // TODO(ben,jie): this is a temp hack.
  op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
  // op_registry_["AvgPool"] = ConvertPool;

  // resnet_50_v1 slim implementation
  op_registry_["Add"] = ConvertBinary;
  op_registry_["Mul"] = ConvertBinary;
  op_registry_["Sub"] = ConvertBinary;
  op_registry_["Rsqrt"] = ConvertUnary;
  op_registry_["Mean"] = ConvertReduce;
  op_registry_["Pad"] = ConvertPad;
  // TODO(ben,jie): Add more ops
}

}  // namespace

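// Converts the TF subgraph identified by subgraph_node_ids into a single
// TRTEngineOp NodeDef: the subgraph is rebuilt as a TensorRT network, an
// engine is built and serialized, and the plan plus the input/output
// metadata are attached to trt_node as attributes.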
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>& input_inds,
    const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
    size_t max_workspace_size_bytes,
    const tensorflow::grappler::GraphProperties& graph_properties,
    tensorflow::NodeDef* trt_node) {
  // Visit nodes in reverse topological order and construct the TRT network.

  // Toposort
  std::vector<tensorflow::Node*> order_vec;
  tensorflow::GetPostOrder(graph, &order_vec);
  // Select just the subgraph
  std::list<tensorflow::Node*> order;
  for (tensorflow::Node* node : order_vec) {
    if (subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the network layer by layer.
      // GetPostOrder emits each node after its consumers, so pushing onto the
      // front of the list yields producer-before-consumer order.
      order.push_front(node);
    }
  }
  // Topological order is needed to build the TRT network.

  tensorflow::tensorrt::Logger trt_logger;

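  // infer_object wraps the raw TensorRT pointer in a smart pointer (note the
  // trt_network.get() call below), so the builder and network objects are
  // released automatically when they go out of scope.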
  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
  if (!trt_builder) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT builder object");
  }

  auto trt_network = infer_object(trt_builder->createNetwork());
  if (!trt_network) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT network object");
  }

  // Build the network
  Converter converter(trt_network.get());

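  // Declare each subgraph input as a TensorRT network input: record its name
  // and dtype and take its shape from grappler's inferred output properties.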
  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (std::pair<int, int> const& input : input_inds) {
    int node_id = input.first;
    int output_idx = input.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    auto node_name = node->name();
    input_names.push_back(node_name);  // Insert original node name without port
    // TODO(jie): alternative :)
    if (!graph_properties.HasOutputProperties(node_name))
      return tensorflow::errors::Internal("Failed to find input node: " +
                                          node_name);

    auto op_info_vec = graph_properties.GetOutputProperties(node_name);
    if (static_cast<int>(op_info_vec.size()) < output_idx)
      return tensorflow::errors::Internal(
          "Accessing output index of: " + std::to_string(output_idx) +
          ", at node: " + node_name + " with output entry from shape_map: " +
          std::to_string(op_info_vec.size()));

    auto op_info = op_info_vec.at(output_idx);

    tensorflow::DataType tf_dtype = op_info.dtype();
    input_dtypes.push_back(tf_dtype);

    nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

    VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
            << ", at node: " << node_name
            << " with output entry from shape_map: "
            << std::to_string(op_info_vec.size());

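    // Dimension 0 of the inferred shape is the batch size; the TRT network
    // input excludes it since batching is handled implicitly via the
    // builder's max batch size, so only the remaining dims are copied below.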
    // TODO(ben,jie): update TRT input format/dimension
    nvinfer1::DimsCHW input_dim_pseudo_chw;
    for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;

    for (int i = 1; i < op_info.shape().dim_size(); i++) {
      VLOG(2) << "dimension: " << i
              << " , size: " << op_info.shape().dim(i).size();
      input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
    }

    // TODO(ben,jie): proper way to restore input tensor name?
    auto input_tensor_name = node_name;
    if (output_idx != 0)
      input_tensor_name = node_name + ":" + std::to_string(output_idx);

    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
        input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);

    if (!input_tensor)
      return tensorflow::errors::InvalidArgument(
          "Failed to create Input layer");
    VLOG(2) << "Input tensor name: " << input_tensor_name;

    if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
      return tensorflow::errors::AlreadyExists(
          "Output tensor already exists for op: " + input_tensor_name);
  }

  VLOG(2) << "Finished sorting";

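  // Convert the subgraph nodes, in topological order, into TensorRT layers.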
  for (const tensorflow::Node* node : order) {
    const tensorflow::NodeDef& node_def = node->def();
    VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
    TF_RETURN_IF_ERROR(converter.convert_node(node_def));
  }

  VLOG(2) << "Finished conversion";

  // Gather output metadata
  std::vector<string> output_names;
  std::vector<tensorflow::DataType> output_dtypes;
  for (std::pair<int, int> const& output : output_inds) {
    int node_id = output.first;
    int output_idx = output.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    string op_name = node->name();
    string tensor_name = op_name;
    if (output_idx != 0)
      tensor_name = tensor_name + ":" + std::to_string(output_idx);
    VLOG(2) << "Output tensor name: " << tensor_name;
    output_names.push_back(tensor_name);
    auto tensor_or_weights = converter.get_tensor(tensor_name);
    if (!tensor_or_weights.is_tensor()) {
      return tensorflow::errors::InvalidArgument(
          "Output node is weights not tensor");
    }
    nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
    if (!tensor) {
      return tensorflow::errors::NotFound("Output tensor not found: " +
                                          tensor_name);
    }
    converter.network()->markOutput(*tensor);
    tensorflow::DataType tf_dtype = node->output_type(output_idx);
    output_dtypes.push_back(tf_dtype);
    nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
    TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
    tensor->setType(trt_dtype);
  }

  VLOG(2) << "Finished output";
  // TODO(jie): static_id is not thread safe.
  static int static_id = 0;

  // Build the engine
  trt_builder->setMaxBatchSize(max_batch_size);
  trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
  VLOG(0) << "Starting build engine " << static_id;
  // TODO(ben,jie): half2 and int8 mode support
  string engine_plan_string;
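  // Build and serialize the engine inside a nested scope so the engine and
  // plan objects go out of scope as soon as the plan bytes have been copied
  // into engine_plan_string.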
  {
    auto trt_engine =
        infer_object(trt_builder->buildCudaEngine(*converter.network()));
    VLOG(0) << "Built network";
    auto engine_plan = infer_object(trt_engine->serialize());
    VLOG(0) << "Serialized engine";
    const char* engine_plan_data =
        static_cast<const char*>(engine_plan->data());
    engine_plan_string =
        string(engine_plan_data, engine_plan_data + engine_plan->size());
  }

  VLOG(0) << "Finished engine";

  // Build the TRT op
  // TODO(sami,ben,jie): proper naming!
  tensorflow::NodeDefBuilder op_builder(
      tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
  for (size_t i = 0; i < input_names.size(); ++i) {
    int output_idx = input_inds.at(i).second;
    // The inputs are already wired up here; it is redundant to do it again in
    // ConvertSubGraphToTensorRT (convert_graph.cc).
    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
        input_names.at(i), output_idx, input_dtypes.at(i));
    income_edges.push_back(incoming_edge);
  }
  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
      income_edges);
  op_builder.Input(input_list);

  VLOG(0) << "Finished op preparation";

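  // Attach the serialized engine plan and the input/output metadata as
  // attributes of the generated TRTEngineOp node.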
  auto status = op_builder.Attr("serialized_engine", engine_plan_string)
                    .Attr("input_nodes", input_names)
                    .Attr("output_nodes", output_names)
                    .Attr("OutT", output_dtypes)
                    .Finalize(trt_node);

  VLOG(0) << status.ToString() << " finished op building";

  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA