/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
37bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney#include "tensorflow/core/platform/tensor_coding.h" 38bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney#include "tensorflow/core/platform/types.h" 39825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 40ae740a67bdc01b991ead6ac047c774bff4d7bc8fJie#if GOOGLE_CUDA 41ae740a67bdc01b991ead6ac047c774bff4d7bc8fJie#if GOOGLE_TENSORRT 42ae740a67bdc01b991ead6ac047c774bff4d7bc8fJie#include "tensorflow/contrib/tensorrt/log/trt_logger.h" 438e03944589542bd64559d68989bca4a4705eed93gracehoney#include "tensorrt/include/NvInfer.h" 448e03944589542bd64559d68989bca4a4705eed93gracehoney 45d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559gracehoney// Check if the types are equal. Cast to int first so that failure log message 46d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559gracehoney// would work! 47825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) 488e03944589542bd64559d68989bca4a4705eed93gracehoney 496908cc233c679b8fe61d99a30d3828362caf47beSami Kamanamespace tensorflow { 50825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamanamespace tensorrt { 51825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamanamespace convert { 52825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 53825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamanamespace { 54825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 55bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneyinline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, 56bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney nvinfer1::DataType* trt_dtype) { 57825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (tf_dtype) { 58825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_FLOAT: 59825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *trt_dtype = nvinfer1::DataType::kFLOAT; 60825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 61825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_INT8: 
62825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *trt_dtype = nvinfer1::DataType::kINT8; 63825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 64825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_HALF: 65825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *trt_dtype = nvinfer1::DataType::kHALF; 66825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 67825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 68825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument("Unsupported data type"); 69825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 70825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 71825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 72825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 73bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneyinline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { 74825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::Dims dims; 75825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama dims.nbDims = tensor.dims(); 76825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int i = 0; i < dims.nbDims; i++) { 77825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama dims.d[i] = tensor.dim_size(i); 78825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 79825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return dims; 80825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 81825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 82bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneyinline int64_t GetShapeSize(nvinfer1::Dims shape) { 83825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Returns total number of elements in shape 84825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int64_t count = 1; 85825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int d = 0; d < shape.nbDims; ++d) { 86825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama count *= shape.d[d]; 87825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 
88825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return count; 89825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 90825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 91bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneystatic std::vector<std::pair<int, int>> CreateSamePadding( 92bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, 93bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const std::vector<int64_t>& input_dims) { 94bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney std::vector<std::pair<int, int>> padding(input_dims.size()); 95bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? 96825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 97bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (size_t i = 0; i < input_dims.size(); ++i) { 98bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Formula to calculate the padding 99bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] - 100bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney input_dims[i]; 101825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama p = (p > 0) ? 
p : 0; 102825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 103bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Right precedence padding, like in TensorFlow 104825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int left = p / 2; 105825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int right = p - left; 106825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 107f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right 108bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney << "paras: " << input_dims[i] << ", " << stride.d[i] << ", " 109f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << "kernel: " << kernel.d[i]; 110825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama padding[i] = {left, right}; 111825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 112825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return padding; 113825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 114825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 115825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamaclass TRT_ShapedWeights { 116825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama public: 117bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TRT_ShapedWeights(tensorflow::DataType type, const void* values, 118b4d59a8f82437b6da897c2c5fe773db7127efdd8Jie nvinfer1::Dims shape) 119b4d59a8f82437b6da897c2c5fe773db7127efdd8Jie : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) { 120bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Note: this->shape.type[] is not used 121bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney } 122bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 123bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney explicit TRT_ShapedWeights(tensorflow::DataType type) 124b4d59a8f82437b6da897c2c5fe773db7127efdd8Jie : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} 125bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 126cfa374cefe132be886c26a374c51454177c68868gracehoney 
TRT_ShapedWeights(const TRT_ShapedWeights& rhs) 127cfa374cefe132be886c26a374c51454177c68868gracehoney : shape_(rhs.shape_), 128cfa374cefe132be886c26a374c51454177c68868gracehoney type_(rhs.type_), 129cfa374cefe132be886c26a374c51454177c68868gracehoney values_(rhs.values_), 130cd63c718be123324b6c39e0f8fbe453319799746Jie empty_weight_flag_(rhs.empty_weight_flag_) {} 131bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 132825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int64_t count() const { 133825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int64_t c = 1; 134825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; 135825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return c; 136825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 137bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 138bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney nvinfer1::Weights GetWeightsForTRT() const { 139825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); 140bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TF_CHECK_OK(ConvertDType(type_, &trt_type)); 141cd63c718be123324b6c39e0f8fbe453319799746Jie if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0}; 142825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 143825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Note: this->shape.type[] is not used 144cfa374cefe132be886c26a374c51454177c68868gracehoney return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)}; 145cfa374cefe132be886c26a374c51454177c68868gracehoney } 146cfa374cefe132be886c26a374c51454177c68868gracehoney 147b4d59a8f82437b6da897c2c5fe773db7127efdd8Jie const void* GetValues() const { return values_; } 148cfa374cefe132be886c26a374c51454177c68868gracehoney 149b4d59a8f82437b6da897c2c5fe773db7127efdd8Jie void SetValues(const void* values) { values_ = values; } 150bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 
151825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama size_t size_bytes() const { 15275adab6104362d71ce28b0269bf31fd30471b1b6Jie int type_size = tensorflow::DataTypeSize(this->type_); 15375adab6104362d71ce28b0269bf31fd30471b1b6Jie return this->count() * type_size; 154825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 155bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 156bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Default converter 157bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney operator nvinfer1::Weights() const { return GetWeightsForTRT(); } 158bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 159bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney nvinfer1::Dims shape_; 160bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensorflow::DataType type_; 161cfa374cefe132be886c26a374c51454177c68868gracehoney 162cfa374cefe132be886c26a374c51454177c68868gracehoney private: 163bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const void* values_; 164cd63c718be123324b6c39e0f8fbe453319799746Jie bool empty_weight_flag_; 165825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}; 166825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 167825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamaclass TRT_TensorOrWeights { 168825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama public: 169825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor) 170cd63c718be123324b6c39e0f8fbe453319799746Jie : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {} 171cfa374cefe132be886c26a374c51454177c68868gracehoney explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights) 172cd63c718be123324b6c39e0f8fbe453319799746Jie : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {} 173cfa374cefe132be886c26a374c51454177c68868gracehoney TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) 1741726dfc979de175eb093b3ee88907fbdc238ce79Jie : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {} 
175bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney ~TRT_TensorOrWeights() {} 176bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 177cd63c718be123324b6c39e0f8fbe453319799746Jie bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; } 178cd63c718be123324b6c39e0f8fbe453319799746Jie bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; } 179bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 180825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* tensor() { 181cfa374cefe132be886c26a374c51454177c68868gracehoney CHECK_EQ(is_tensor(), true); 182cd63c718be123324b6c39e0f8fbe453319799746Jie return tensor_; 183825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 184cfa374cefe132be886c26a374c51454177c68868gracehoney const nvinfer1::ITensor* tensor() const { 185cfa374cefe132be886c26a374c51454177c68868gracehoney CHECK_EQ(is_tensor(), true); 186cd63c718be123324b6c39e0f8fbe453319799746Jie return tensor_; 187825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 188825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights& weights() { 189cfa374cefe132be886c26a374c51454177c68868gracehoney CHECK_EQ(is_weights(), true); 190cd63c718be123324b6c39e0f8fbe453319799746Jie return weights_; 191825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 192bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const TRT_ShapedWeights& weights() const { 193cfa374cefe132be886c26a374c51454177c68868gracehoney CHECK_EQ(is_weights(), true); 194cd63c718be123324b6c39e0f8fbe453319799746Jie return weights_; 195825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 196825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::Dims shape() const { 197cfa374cefe132be886c26a374c51454177c68868gracehoney if (is_tensor()) { 198cfa374cefe132be886c26a374c51454177c68868gracehoney return tensor()->getDimensions(); 199825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 200cfa374cefe132be886c26a374c51454177c68868gracehoney return weights().shape_; 
201825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 202825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 203825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 204bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney private: 205cd63c718be123324b6c39e0f8fbe453319799746Jie nvinfer1::ITensor* tensor_; 206cd63c718be123324b6c39e0f8fbe453319799746Jie TRT_ShapedWeights weights_; 207cd63c718be123324b6c39e0f8fbe453319799746Jie enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_; 208825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}; 209825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 210825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamaclass TFAttrs { 211825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama public: 212bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney explicit TFAttrs(const tensorflow::NodeDef& tf_node) { 213bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (const auto& attr : tf_node.attr()) { 214cd63c718be123324b6c39e0f8fbe453319799746Jie attrs_.insert({attr.first, &attr.second}); 215825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 216825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 217cd63c718be123324b6c39e0f8fbe453319799746Jie bool count(string key) const { return attrs_.count(key); } 218bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensorflow::AttrValue const* at(string key) const { 219cd63c718be123324b6c39e0f8fbe453319799746Jie if (!attrs_.count(key)) { 2208e03944589542bd64559d68989bca4a4705eed93gracehoney LOG(FATAL) << "Attribute not found: " << key; 221825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 222cd63c718be123324b6c39e0f8fbe453319799746Jie return attrs_.at(key); 223825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 224825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 225bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney T get(string key) const; 226825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 227bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney T get(string key, const T& 
default_value) const { 228cd63c718be123324b6c39e0f8fbe453319799746Jie return attrs_.count(key) ? this->get<T>(key) : default_value; 229825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 230bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 231bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney private: 232bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney typedef std::map<string, tensorflow::AttrValue const*> AttrMap; 233cd63c718be123324b6c39e0f8fbe453319799746Jie AttrMap attrs_; 234825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}; 235825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 236825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <> 237bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneystring TFAttrs::get<string>(string key) const { 238825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return this->at(key)->s(); 239825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 240bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 241825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <> 242bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneystd::vector<int> TFAttrs::get<std::vector<int>>(string key) const { 243825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto attr = this->at(key)->list().i(); 244825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return std::vector<int>(attr.begin(), attr.end()); 245825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 246bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 247825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <> 248bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneynvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const { 249825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto values = this->get<std::vector<int>>(key); 250825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::Dims dims; 251825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama dims.nbDims = values.size(); 252825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::copy(values.begin(), values.end(), dims.d); 
253825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Note: No dimension type information is included 254825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return dims; 255825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 2566908cc233c679b8fe61d99a30d3828362caf47beSami Kama 257825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <> 258bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneynvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const { 259825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); 260bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); 261825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return trt_dtype; 262825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 2636908cc233c679b8fe61d99a30d3828362caf47beSami Kama 264825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <> 265bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneytensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const { 266825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return this->at(key)->type(); 267825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 268825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 269825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <typename T> 270cfa374cefe132be886c26a374c51454177c68868gracehoneyvoid Reorder4(nvinfer1::DimsNCHW shape, const T* idata, 271825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsNCHW istrides, T* odata, 272825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsNCHW ostrides) { 273825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int n = 0; n < shape.n(); ++n) { 274825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int c = 0; c < shape.c(); ++c) { 275825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int h = 0; h < shape.h(); ++h) { 276825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int w = 0; w < shape.w(); ++w) { 
277825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() + 278825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() + 279825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama h * istrides.h() + w * istrides.w()]; 280825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 281825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 282825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 283825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 284825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 285825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 286bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneyvoid ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, 287bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TRT_ShapedWeights* oweights) { 288825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(iweights.type_, oweights->type_); 289825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); 290825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int r = iweights.shape_.d[0]; 291825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int s = iweights.shape_.d[1]; 292825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int c = iweights.shape_.d[2]; 293825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int k = iweights.shape_.d[3]; 294825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama oweights->shape_.d[0] = k; 295825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama oweights->shape_.d[1] = c; 296825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama oweights->shape_.d[2] = r; 297825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama oweights->shape_.d[3] = s; 298825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; 299825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; 300825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (iweights.type_) 
{ 301825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_FLOAT: 302cfa374cefe132be886c26a374c51454177c68868gracehoney Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()), 303cfa374cefe132be886c26a374c51454177c68868gracehoney istrides, 304cfa374cefe132be886c26a374c51454177c68868gracehoney static_cast<float*>(const_cast<void*>(oweights->GetValues())), 305cfa374cefe132be886c26a374c51454177c68868gracehoney ostrides); 306825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 307825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 308825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!"; 309825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 310825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 311825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 312825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamastruct InferDeleter { 313825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 314825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama void operator()(T* obj) const { 315825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (obj) { 316825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama obj->destroy(); 317825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 318825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 319825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}; 320825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 321825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatemplate <typename T> 322825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamainline std::shared_ptr<T> infer_object(T* obj) { 323825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return std::shared_ptr<T>(obj, InferDeleter()); 324825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 325825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 326825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// Logger for GIE info/warning/errors 327825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamaclass Converter; 
328825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 329825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamausing OpConverter = 330bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&, 331825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const&, 332825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>*)>; 333825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 334825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamaclass Converter { 335cd63c718be123324b6c39e0f8fbe453319799746Jie std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_; 336cd63c718be123324b6c39e0f8fbe453319799746Jie std::unordered_map<string, OpConverter> op_registry_; 337cd63c718be123324b6c39e0f8fbe453319799746Jie nvinfer1::INetworkDefinition* trt_network_; 338cd63c718be123324b6c39e0f8fbe453319799746Jie std::list<std::vector<uint8_t>> temp_bufs_; 339825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 340825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama void register_op_converters(); 341825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 342825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> get_inputs( 343bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def) { 344825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> inputs; 345bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (const auto& input_name : node_def.input()) { 346bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Retrieve input: " << input_name; 347cd63c718be123324b6c39e0f8fbe453319799746Jie inputs.push_back(trt_tensors_.at(input_name)); 348825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 349825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return inputs; 350825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 351825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 
 public:
  // Builds the op-converter registry at construction. `trt_network` is the
  // TensorRT network being populated; it is owned by the caller and must
  // outlive this Converter.
  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
      : trt_network_(trt_network) {
    this->register_op_converters();
  }

  // Allocates a scratch buffer (kept alive in temp_bufs_ for the lifetime of
  // this Converter) and returns weights of the given type/shape backed by it.
  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                     nvinfer1::Dims shape) {
    TRT_ShapedWeights weights(type, nullptr, shape);
    // TODO(jie): check weights size_bytes. 0 means type error
    temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    weights.SetValues(temp_bufs_.back().data());
    return weights;
  }

  // Convenience wrapper: scratch weights with the same type and shape as
  // `weights`.
  TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    return this->get_temp_weights(weights.type_, weights.shape_);
  }

  // Converts one TF node: looks up the OpConverter registered for
  // node_def.op(), runs it on the node's inputs, and records every produced
  // output under "<node_name>" (first output) or "<node_name>:<i>" in
  // trt_tensors_ so downstream nodes can find them.
  // Returns Unimplemented when no converter is registered, AlreadyExists on a
  // duplicate output name, or the converter's own error.
  tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    string op = node_def.op();
    if (!op_registry_.count(op)) {
      return tensorflow::errors::Unimplemented(
          "No converter registered for op: " + op);
    }
    OpConverter op_converter = op_registry_.at(op);
    std::vector<TRT_TensorOrWeights> outputs;
    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    for (size_t i = 0; i < outputs.size(); ++i) {
      TRT_TensorOrWeights output = outputs.at(i);
      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
      string output_name = node_def.name();
      if (i != 0) output_name = output_name + ":" + std::to_string(i);
      if (output.is_tensor()) {
        output.tensor()->setName(output_name.c_str());
      }
      VLOG(2) << "Write out tensor: " << output_name;
      if (!trt_tensors_.insert({output_name, output}).second) {
        return tensorflow::errors::AlreadyExists(
            "Output tensor already exists for op: " + op);
      }
    }
    return tensorflow::Status::OK();
  }

  // The TensorRT network under construction. Not owned.
  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  // Looks up a previously recorded tensor/weights by name. Returns a wrapper
  // around a null tensor when the name is unknown (callers must check).
  TRT_TensorOrWeights get_tensor(string name) {
    if (!trt_tensors_.count(name)) {
      return TRT_TensorOrWeights(nullptr);
    }
    return trt_tensors_.at(name);
  }

  // Registers an engine input tensor under `name`; returns false (and leaves
  // the map unchanged) if the name is already taken.
  bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
  }

  // Permutes the dimensions of `input_tensor` via an IShuffleLayer.
  // `order` is a TF-style permutation that still includes the batch dimension
  // at index 0; TRT dims exclude the batch dimension, hence the `- 1` /
  // `i + 1` offsets below. Returns the shuffled tensor.
  nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
                                     std::vector<int> order) {
    auto dims = input_tensor->getDimensions();

    // TODO(jie): change the return to status and properly exit
    // NOTE(review): on a rank mismatch this only logs and then proceeds,
    // reading out-of-range entries of `order` below.
    if (order.size() - 1 != size_t(dims.nbDims))
      LOG(ERROR) << "Dimension does not match, fail gracefully";

    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);

    nvinfer1::Permutation permutation;
    for (int32_t i = 0; i < dims.nbDims; ++i) {
      permutation.order[i] = order[i + 1] - 1;
    }
    layer->setFirstTranspose(permutation);

    // A reshape dimension of 0 tells TRT to copy the corresponding input
    // dimension, so this reshape preserves the (permuted) extents.
    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = dims.nbDims;
    for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
      reshape_dims.d[i] = 0;
      reshape_dims.type[i] = dims.type[i];
    }
    layer->setReshapeDimensions(reshape_dims);
    return layer->getOutput(0);
  }
};

// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
// there.
4416908cc233c679b8fe61d99a30d3828362caf47beSami Kama//*****************************************************************************/ 442825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamastruct LambdaFactory { 443825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; 444825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama OP_CATEGORY op; 445825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 446825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 447825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::function<T(T)> unary() { 448825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (op) { 449825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::RSQRT: { 450f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "RSQRT GETS DONE"; 451825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T t) -> T { return 1.0 / std::sqrt(t); }; 452825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 453825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::NEG: 454825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T t) -> T { return -t; }; 455825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 456bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Not supported op for unary: " << static_cast<int>(op); 457825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return nullptr; 458825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 459825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 460825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 461825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 462825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::function<T(T, T)> binary() { 463825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (op) { 464825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::ADD: 465825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T l, T r) -> T { return l + r; }; 
466825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::SUB: 467825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T l, T r) -> T { return l - r; }; 468825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::MUL: 469825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T l, T r) -> T { return l * r; }; 470825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 471bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op); 472825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 473825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [](T l, T r) -> T { 474825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LOG(FATAL) << "Unsupported op type "; 475825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l; 476825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 477825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 478825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 479825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 480825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::function<T(T)> broadcast_r(T val) { 481f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 482825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (op) { 483825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::ADD: 484825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 485f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 486825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l + val; 487825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 488bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Return [val](T l)-> T {return l+val;}; 489825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::SUB: 490825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 491f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << 
"LAMBDA VAL : " << val; 492825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l - val; 493825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 494825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::MUL: 495825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 496f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 497825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l * val; 498825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 499825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 500bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op); 501825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 502825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 503825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LOG(FATAL) << "Unsupported op type "; 504825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l; 505825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 506825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 507825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 508825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama template <typename T> 509825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::function<T(T)> broadcast_l(T val) { 510f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 511825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (op) { 512825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::ADD: 513825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 514f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 515825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return val + l; 516825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 517825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::SUB: 518825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T 
{ 519f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 520825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return val - l; 521825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 522825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case OP_CATEGORY::MUL: 523825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 524f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "LAMBDA VAL : " << val; 525825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return val * l; 526825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 527825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 528bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op); 529825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 530825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return [val](T l) -> T { 531825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LOG(FATAL) << "Unsupported op type "; 532825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return l; 533825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama }; 534825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 535825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama}; 536825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 537bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneytensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, 538825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights* oweights, 539825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LambdaFactory unary_op) { 540825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(iweights.type_, oweights->type_); 541825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (iweights.type_) { 542825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_FLOAT: { 543cfa374cefe132be886c26a374c51454177c68868gracehoney auto inp = static_cast<float const*>(iweights.GetValues()); 544cfa374cefe132be886c26a374c51454177c68868gracehoney auto oup 
= static_cast<float*>(const_cast<void*>(oweights->GetValues())); 545825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>()); 546825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 547825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 548825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 549bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney return tensorflow::errors::Unimplemented( 550bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Data type not supported: " + 551bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensorflow::DataTypeString(iweights.type_)); 552825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 553825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 554825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 555825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 556bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoneytensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, 557bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const TRT_ShapedWeights& iweights_r, 558825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights* oweights, 559825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LambdaFactory binary_op) { 560bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Assume iweights_l.type == iweight_r.type 561825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(iweights_l.type_, oweights->type_); 562825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(iweights_r.type_, oweights->type_); 563f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "SANITY CHECK!"; 564825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 565825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama switch (iweights_l.type_) { 566825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama case tensorflow::DataType::DT_FLOAT: { 567cfa374cefe132be886c26a374c51454177c68868gracehoney auto inp_l = static_cast<const float*>(iweights_l.GetValues()); 
568cfa374cefe132be886c26a374c51454177c68868gracehoney auto inp_r = static_cast<const float*>(iweights_r.GetValues()); 569cfa374cefe132be886c26a374c51454177c68868gracehoney auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues())); 570825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 571825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (iweights_l.count() != iweights_r.count()) { 572bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // We only supports broadcast of RankZero 573825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (iweights_l.count() == 1) { 574f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "I bet it is not working!" << (*inp_l); 575825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::transform(inp_r, inp_r + iweights_r.count(), oup, 576825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.broadcast_l<float>(*inp_l)); 577825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else if (iweights_r.count() == 1) { 578f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "I bet it is not working!" 
<< (*inp_r); 579825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::transform(inp_l, inp_l + iweights_l.count(), oup, 580825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.broadcast_r<float>(*inp_r)); 581825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 582825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 583825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Binary op with non-rankZero broadcast not supported"); 584825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 585825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 586825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup, 587825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.binary<float>()); 588825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 589825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 590825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 591825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama default: 592bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney return tensorflow::errors::Unimplemented( 593bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Data type not supported: " + 594bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensorflow::DataTypeString(iweights_l.type_)); 595825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 596825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 597825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 598825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 599825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 600825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConstantFoldUnary( 601bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney Converter& ctx, const tensorflow::NodeDef& node_def, 602825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 603825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) 
{ 604825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights weights_input = inputs.at(0).weights(); 605825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 606bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Allocate output weights 607825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); 608825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 609825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // FIXME assume type matches input weights 610bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Get trt type & shape 611bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Maybe this part has to be moved into the block of rsqrt later 612bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Check type consistency 613825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(weights_input.type_, 614825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TFAttrs(node_def).get<tensorflow::DataType>("T")); 615825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 616825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Maybe I should do a switch 617825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LambdaFactory unary_op; 618825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (node_def.op() == "Rsqrt") { 619bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Compute rsqrt 620825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; 621825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto ret = UnaryCompute(weights_input, &weights_output, unary_op); 622bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // PAss the output 623825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (ret == tensorflow::Status::OK()) { 624825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(TRT_TensorOrWeights(weights_output)); 625825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 626825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return ret; 
627825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 628825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented("Binary op not supported: " + 629825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.op()); 630825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 631825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 632825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 633825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// TODO(jie,ben) broadcast is needed yet not implemented 634825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// Let's get the simple stuff working first. Maybe we should fall bakc to TF 635825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama// approach for constant folding 636825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConstantFoldBinary( 637bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney Converter& ctx, const tensorflow::NodeDef& node_def, 638825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 639825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 640825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); 641825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); 642825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 643bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Check type consistency 644825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama CHECK_EQ(weights_input_l.type_, weights_input_r.type_); 645825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 646825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) 647825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 648825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Binary op implicit broadcast not supported: " + node_def.op()); 
649825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 650825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): constant fold should really fall back to TF. 651bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney int nb_dims = weights_input_l.shape_.nbDims; 652825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::Dims output_shape; 653bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney output_shape.nbDims = nb_dims; 654bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "nb_dims: " << nb_dims 655bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney << ", the other: " << weights_input_r.shape_.nbDims; 656bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (int i = 0; i < nb_dims; i++) { 657825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { 658825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama output_shape.d[i] = weights_input_l.shape_.d[i]; 659825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else if (weights_input_l.shape_.d[i] == 1 || 660825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama weights_input_r.shape_.d[i] == 1) { 661825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama output_shape.d[i] = 662825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); 663825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 664825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 665825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Binary op with incompatible shape at, " + node_def.op()); 666825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 667f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "left: " << weights_input_l.shape_.d[i] 668f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << "right: " << weights_input_r.shape_.d[i] 669f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << "output: " << output_shape.d[i]; 670825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 
671825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 672825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // FIXME assume type matches input weights 673bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Get trt type & shape 674825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TFAttrs attrs(node_def); 675bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Maybe this part has to be moved into the block of rsqrt later 676825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T"); 677825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 678bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Allocate output weights 679825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); 680825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 681825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Maybe I should do a switch 682825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LambdaFactory binary_op; 683825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (node_def.op() == "Sub") { 684825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.op = LambdaFactory::OP_CATEGORY::SUB; 685825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else if (node_def.op() == "Mul") { 686825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.op = LambdaFactory::OP_CATEGORY::MUL; 687825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else if (node_def.op() == "Add") { 688825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op.op = LambdaFactory::OP_CATEGORY::ADD; 689825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 690825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented("Binary op not supported: " + 691825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.op()); 692825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 693825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto ret = BinaryCompute(weights_input_l, 
weights_input_r, &weights_output, 694825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama binary_op); 695825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 696bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Pass the output 697825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (ret == tensorflow::Status::OK()) { 698825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(TRT_TensorOrWeights(weights_output)); 699825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 700825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 701825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return ret; 702825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 703825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 704bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney// TODO(jie): broadcast is needed yet not implemented. 705bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney// Only implemented channel wise for the time being 706825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status BinaryTensorOpWeight( 707bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney Converter& ctx, const tensorflow::NodeDef& node_def, 708825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, 709825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 710825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // FIXME assume type matches input weights 711bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Get trt type & shape 712bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Maybe this part has to be moved into the block of rsqrt later 713825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 714bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Check type consistency 715825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T"); 716bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney CHECK_EQ_TYPE(tensor->getType(), dtype); // Cast to int for error messages 
717825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DataType ttype; 718bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TF_CHECK_OK(ConvertDType(weights.type_, &ttype)); 719bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney CHECK_EQ_TYPE(ttype, dtype); // Cast to int for error message 720825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 721bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Check scale mode 722825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto dims_w = weights.shape_; 723825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto dims_t = tensor->getDimensions(); 724825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 725bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Default to channel-wise 726825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; 727825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 728825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (weights.count() == 1) { 729f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "UNIFORM"; 730825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama scale_mode = nvinfer1::ScaleMode::kUNIFORM; 731825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 732bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // No broadcasting on Batch dimension; 733e01844e65e0dbd2682a894946bec7f072d36fa27Jie assert(dims_w.d[0] == 1); 734825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 735bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Broadcasting on Channel dimension only allowed in kUNIFORM 736e01844e65e0dbd2682a894946bec7f072d36fa27Jie assert(dims_w.d[1] == dims_t.d[0]); 737e01844e65e0dbd2682a894946bec7f072d36fa27Jie assert(dims_w.nbDims == dims_t.nbDims); 738825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 739bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Default is element; 740e01844e65e0dbd2682a894946bec7f072d36fa27Jie for (int i = 2; i < dims_w.nbDims; i++) { 741e01844e65e0dbd2682a894946bec7f072d36fa27Jie if (dims_w.d[i] != 
dims_t.d[i - 1]) { 742825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama scale_mode = nvinfer1::ScaleMode::kCHANNEL; 743825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 744825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 745825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 746825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) { 747825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; 748e01844e65e0dbd2682a894946bec7f072d36fa27Jie for (int i = 2; i < dims_w.nbDims; i++) { 749e01844e65e0dbd2682a894946bec7f072d36fa27Jie if (dims_w.d[i] != 1) 750825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 751e01844e65e0dbd2682a894946bec7f072d36fa27Jie "Weight shape not compatible at, " + node_def.name()); 752825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 753825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 754825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 755825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 756bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Prepare weights 757bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TRT_ShapedWeights shift_weights(weights.type_); 758bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TRT_ShapedWeights scale_weights(weights.type_); 759bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TRT_ShapedWeights power_weights(weights.type_); 760825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 761825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Maybe I should do a switch 762825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (node_def.op() == "Sub") { 763825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); 764825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama LambdaFactory unary_op; 765825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama unary_op.op = LambdaFactory::OP_CATEGORY::NEG; 
    TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    shift_weights = neg_weights;
  } else if (node_def.op() == "Mul") {
    scale_weights = weights;
  } else if (node_def.op() == "Add") {
    shift_weights = weights;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }

  // Emit a single TRT scale layer implementing y = (x * scale + shift)^power.
  // Only the weight matching the TF op was populated above; the others stay
  // empty, which TRT treats as identity.
  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
      scale_weights, power_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

// Converts an element-wise binary op whose operands are BOTH TRT tensors
// (as opposed to tensor-vs-weights, handled by BinaryTensorOpWeight) into a
// TRT IElementWiseLayer.
//
// Supported ops: Add, Mul, Sub, Div (see the `ops` table below).
// Returns Unimplemented for any other op; on success pushes the layer output
// onto `outputs`.
tensorflow::Status BinaryTensorOpTensor(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    std::vector<TRT_TensorOrWeights>* outputs) {
  // TF op name -> TRT element-wise op. Kept static: built once, read-only.
  static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
      {"Add", nvinfer1::ElementWiseOperation::kSUM},
      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
      // {"max", nvinfer1::ElementWiseOperation::kMAX},
      // {"min", nvinfer1::ElementWiseOperation::kMIN},
      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
      {"Div", nvinfer1::ElementWiseOperation::kDIV},
  };

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");

  // Check type consistency: both operands must match the declared "T" attr.
  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
  auto op_pair = ops.find(node_def.op());
  if (op_pair == ops.end())
    return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
                                             " not supported at: " +
                                             node_def.name());

  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
      *const_cast<nvinfer1::ITensor*>(tensor_l),
      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

// Placeholder nodes are expected to have been replaced by engine inputs
// before conversion runs, so reaching this converter is always an error.
tensorflow::Status ConvertPlaceholder(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  VLOG(2) << "Placeholder should have been replace already";
  return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
  // NOTE(review): everything from here to the closing brace is unreachable —
  // the function unconditionally returns Unimplemented above. Consider
  // deleting this dead code in a follow-up.
  // OK this make sense since we are supposed to replace it with input
  TFAttrs attrs(node_def);
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");

  // Strip the leading (batch) dimension: TRT works with implicit batch.
  dims.nbDims--;
  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];

  nvinfer1::ITensor* output =
      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
  if (!output) {
    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
  }
  outputs->push_back(TRT_TensorOrWeights(output));
  return tensorflow::Status::OK();
}

// Converts a TF Conv2D node into a TRT IConvolutionLayer.
//
// inputs.at(0): the activation tensor; inputs.at(1): the filter weights in
// TF's RSCK layout, reordered here to TRT's KCRS. NHWC inputs are transposed
// to NCHW before the convolution (and back afterwards, in the tail of this
// function). SAME padding is computed manually; asymmetric padding is
// realized with an explicit IPaddingLayer since IConvolutionLayer only
// supports symmetric pre-padding.
tensorflow::Status ConvertConv2D(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 const std::vector<TRT_TensorOrWeights>& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
  ReorderRSCKToKCRS(weights_rsck, &weights);
  // No bias input in TF Conv2D; pass empty weights so TRT adds no bias.
  TRT_ShapedWeights biases(weights.type_);
  int noutput = weights.shape_.d[0];
  nvinfer1::DimsHW kernel_size;
  kernel_size.h() = weights.shape_.d[2];
  kernel_size.w() = weights.shape_.d[3];
  TFAttrs attrs(node_def);

  // Stride/kernel attr indices assume NCHW; adjusted below for NHWC.
  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    h_index = 1;
    w_index = 2;
    // TODO(jie): transpose it
  }

  // TODO(jie): stride. (NHWC/NCHW)
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is NCHW tensor with no batch dimension.
    // 1 -> h
    // 2 -> w
    padding = CreateSamePadding(
        stride, kernel_size,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else {
    padding = {{0, 0}, {0, 0}};
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
            << padding[1].first << padding[1].second;

    auto dim_before = tensor->getDimensions();
    VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
            << dim_before.d[2] << ", " << dim_before.d[3];
    // Asymmetric SAME padding: insert an explicit padding layer, then zero
    // out `padding` so the convolution itself adds none.
    // NOTE(review): addPadding's return value is not null-checked here.
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
    auto dim_after = tensor->getDimensions();
    VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
            << dim_after.d[2] << ", " << dim_after.d[3];
  }

  nvinfer1::IConvolutionLayer* layer =
      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
                                    noutput, kernel_size, weights, biases);

  layer->setStride(stride);
  // At this point padding is symmetric (or zero), so first == second.
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
nvinfer1::ITensor* output_tensor = layer->getOutput(0); 923825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 92424e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie auto dim_after = output_tensor->getDimensions(); 925f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] 926f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << dim_after.d[2] << ", " << dim_after.d[3]; 92724e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie 928825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (data_format == "NHWC") { 929825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): transpose it back! 930bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); 931825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 932f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "NCHW !!!!"; 933825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 934825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(TRT_TensorOrWeights(output_tensor)); 935825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 936825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 937825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 938825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertPool(Converter& ctx, 939bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def, 940825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 941825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 942825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); 943825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TFAttrs attrs(node_def); 944825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 945825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int h_index = 2; 
946825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int w_index = 3; 947bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney auto data_format = attrs.get<string>("data_format"); 948825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (data_format == "NHWC") { 949825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama h_index = 1; 950825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama w_index = 2; 951bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), 952825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama {0, 3, 1, 2}); 953825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 954f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "NCHW !!!!"; 955825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 956825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::PoolingType type; 957825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): support other pooling type 958825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (node_def.op() == "MaxPool") 959825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama type = nvinfer1::PoolingType::kMAX; 960825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama else 961bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney return tensorflow::errors::Unimplemented("Only supports Max pool"); 962825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 963825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): NCHW 964825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto tf_stride = attrs.get<std::vector<int>>("strides"); 965825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); 966825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 967825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto tf_kernel = attrs.get<std::vector<int>>("ksize"); 968825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); 969825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 
970825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto tensor_dim = tensor->getDimensions(); 971825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<std::pair<int, int>> padding; 972825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): padding. 973bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney if (attrs.get<string>("padding") == "SAME") { 974825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // This is NCHW tensor with no batch dimension. 975825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // 1 -> h 976825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // 2 -> w 977bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney padding = CreateSamePadding( 978825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama stride, ksize, 979825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])}); 980bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney } else if (attrs.get<string>("padding") == "VALID") { 981825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // No padding for valid padding here 982bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "No padding added for VALID padding in pool" << node_def.name(); 983825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama padding = {{0, 0}, {0, 0}}; 984825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 985825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 986825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Current MaxPool cannot support padding other than SAME"); 987825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 988825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 989825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (padding[0].first != padding[0].second || 990825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama padding[1].first != padding[1].second) { 991825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): handle asymmetric padding 992bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney 
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second 993f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << padding[1].first << padding[1].second; 994bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney auto pad_layer = ctx.network()->addPadding( 995825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *const_cast<nvinfer1::ITensor*>(tensor), 99624e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie nvinfer1::DimsHW(padding[0].first, padding[1].first), 99724e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie nvinfer1::DimsHW(padding[0].second, padding[1].second)); 99824e17d8e2d5adfc2fc8b6fa94b7590006b4d21a9Jie padding = {{0, 0}, {0, 0}}; 999bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensor = pad_layer->getOutput(0); 1000825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1001825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1002825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( 1003825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *const_cast<nvinfer1::ITensor*>(tensor), type, ksize); 1004825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1005825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama layer->setStride(stride); 1006825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama layer->setPadding({padding[0].first, padding[1].first}); 1007825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama layer->setName(node_def.name().c_str()); 1008825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* output_tensor = layer->getOutput(0); 1009825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1010825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (data_format == "NHWC") { 1011825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): transpose it back! 
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

// Converts a TF activation node into a TRT IActivationLayer.
// Only ReLU is emitted here; the op type is not inspected, so callers must
// route only ReLU-compatible nodes to this converter.
tensorflow::Status ConvertActivation(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

// Converts a tensor-times-weights node (e.g. BiasAdd-style patterns) into a
// channel-wise TRT IScaleLayer.
//
// Requires exactly two inputs: a tensor followed by weights; anything else
// returns Unimplemented. NHWC inputs are transposed to NCHW around the scale
// layer. NOTE(review): the weights are applied in kCHANNEL mode without
// verifying that their count matches the channel dimension — presumably the
// caller guarantees this; confirm upstream.
tensorflow::Status ConvertScale(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::Unimplemented(
        "Only supports tensor op weight for now, at " + node_def.name());
  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();

  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights = inputs.at(1).weights();
  // Empty weights act as identity for the scale and power terms.
  TRT_ShapedWeights empty_weights(weights.type_);

  TFAttrs attrs(node_def);

  // Transpose NHWC
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    // TODO(jie): transpose it
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
      weights, empty_weights, empty_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

// Converts a TF Const node into TRT_ShapedWeights pushed onto `outputs`
// (constants become weights, not network tensors).
tensorflow::Status ConvertConst(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  const auto& weights_tensor = node_def.attr().at("value").tensor();
  // Get trt type & shape
  TFAttrs attrs(node_def);
  const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");

  // Create shaped weights as output
  tensorflow::Tensor tensor;
  if (!tensor.FromProto(weights_tensor))
    return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
                                        node_def.name());

  TRT_ShapedWeights weights(dtype);
  if (!weights_tensor.float_val().empty()) {
    // Constant stored as repeated float_val in the proto (typically scalars
    // or small tensors).
    // NOTE(review): the weights alias weights_tensor.float_val().data(),
    // i.e. memory owned by the NodeDef proto — presumably the NodeDef
    // outlives engine construction; confirm against the caller.
    VLOG(2) << "SCALAR!!!" << node_def.name();
    nvinfer1::Dims scalar_shape;
    if (tensor.dims() > 0) {
      VLOG(2) << "Dimensions: " << tensor.dims();
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  GetTensorShape(tensor));
    } else {
      // Rank-0 scalar: synthesize a {1} shape since TRT dims cannot be empty.
      VLOG(2) << "Dimensions: " << tensor.dims();
      scalar_shape.nbDims = 1;
      scalar_shape.d[0] = 1;
      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
        scalar_shape.d[i] = 0;
        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
      }
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  scalar_shape);
    }
  } else if (!weights_tensor.tensor_content().empty()) {
    // Constant stored as packed bytes in tensor_content; copy into
    // converter-owned temporary weights.
    VLOG(2) << "TENSOR!!!" << node_def.name();
    const auto& content = weights_tensor.tensor_content();

    weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
    if (content.size() > 0) {
      const int dtype_size = tensorflow::DataTypeSize(dtype);
      CHECK_EQ(0, content.size() % dtype_size)
          << "Tensor content size (" << content.size()
          << ") is not a multiple of " << dtype_size;
      port::CopyToArray(
          content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
    }
  } else {
    return tensorflow::errors::Unimplemented(
        "Not supported constant type, at " + node_def.name());
  }
  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(weights));
  return tensorflow::Status::OK();
}

// Identity-like nodes need no TRT layer: forward the input unchanged.
tensorflow::Status ConvertIdentity(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
1134825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 1135825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(inputs.at(0)); 1136825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 1137825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1138825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1139825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertBinary(Converter& ctx, 1140bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def, 1141825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 1142825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 1143825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.size() != 2) 1144825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::FailedPrecondition( 1145825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Binary ops require two tensor input, at " + node_def.name()); 1146825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1147825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) 1148825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return ConstantFoldBinary(ctx, node_def, inputs, outputs); 1149825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1150825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) 1151825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), 1152825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama inputs.at(1).weights(), outputs); 1153825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1154825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) 1155825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return BinaryTensorOpWeight(ctx, node_def, 
inputs.at(1).tensor(), 1156825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama inputs.at(0).weights(), outputs); 1157825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1158825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) 1159825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), 1160825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama inputs.at(1).tensor(), outputs); 1161825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1162825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unknown("Binary op input error, at " + 1163825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1164825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1165825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1166825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertUnary(Converter& ctx, 1167bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def, 1168825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 1169825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 1170825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.size() != 1) 1171825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::FailedPrecondition( 1172825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Unary ops require single tensor input, at " + node_def.name()); 1173825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1174825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.at(0).is_weights()) 1175825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return ConstantFoldUnary(ctx, node_def, inputs, outputs); 1176825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama else if (inputs.at(0).is_tensor()) 1177825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 
1178825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Unary op for tensor not supported, at " + node_def.name()); 1179825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1180825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unknown("Binary op input error, at " + 1181825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1182825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1183825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1184825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertReduce(Converter& ctx, 1185bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def, 1186825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 1187825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 1188825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.size() != 2 || !inputs.at(0).is_tensor() || 1189825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama !inputs.at(1).is_weights()) 1190825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1191825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Input expects tensor and weights, at" + node_def.name()); 1192825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1193bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Implement tensor binaryOp weight [channel wise] for now; 1194825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); 1195825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto dims = tensor->getDimensions(); 1196bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Restore implicit batch dimension 1197bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney int nb_dims = dims.nbDims + 1; 1198825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1199825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights index_list = inputs.at(1).weights(); 
1200825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1201825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TFAttrs attrs(node_def); 1202bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // TODO(jie): handle data type. 1203bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Index type here is done through TF type, so I can leverage their 1204bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // EnumToDataType for my cast 1205825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto index_type = attrs.get<tensorflow::DataType>("Tidx"); 1206825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1207825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Only expect to handle INT32 as attributes for now 1208825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (index_type != tensorflow::DataType::DT_INT32) 1209825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); 1210825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto index_list_data = 1211cfa374cefe132be886c26a374c51454177c68868gracehoney static_cast<int*>(const_cast<void*>(index_list.GetValues())); 1212825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1213bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Hack warning: have to fall back to pool layer since reduce is not in public 1214bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // TRT yet. 
1215bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney if (nb_dims != 4) 1216825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1217825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "TRT only support reduce on 4 dimensional tensors, at" + 1218825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1219825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (index_list.count() > 2) 1220825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1221825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "TRT cannot support reduce on more than 2 dimensions, at" + 1222825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1223825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1224825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::set<int> idx_set; 1225bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // We cannot operate on Channel. permutation flag used to transpose tensor 1226825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int permuted_index = -1; 1227825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int i = 0; i < index_list.count(); i++) { 1228825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (index_list_data[i] == 0) 1229825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + 1230825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1231825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (index_list_data[i] == 1) permuted_index = 1; 1232825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama idx_set.emplace(index_list_data[i]); 1233825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1234825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1235bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney std::vector<int> permutation_order(nb_dims); 1236825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsHW pool_kernel; 1237825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if 
(permuted_index == 1) { 1238bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (int i = 2; i < nb_dims; i++) { 1239825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (idx_set.count(i)) { 1240825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama permuted_index = i; 1241825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama break; 1242825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1243825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1244bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (int i = 0; i < nb_dims; i++) permutation_order[i] = i; 1245825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1246825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama permutation_order[permuted_index] = 1; 1247825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama permutation_order[1] = permuted_index; 1248825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1249bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Apply permutation before extracting dimension for pool_kernel 1250bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), 1251825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama permutation_order); 1252825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1253825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1254bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Apply permutation before extracting dimension for pool_kernel 1255825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1; 1256825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? 
dims.d[2] : 1; 1257825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1258825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* output_tensor; 1259825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1260825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (node_def.op() == "Mean") { 1261825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::IPoolingLayer* layer = 1262825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor), 1263825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::PoolingType::kAVERAGE, pool_kernel); 1264825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama output_tensor = layer->getOutput(0); 1265825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else { 1266825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 1267825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Op not supported " + node_def.op() + " , at " + node_def.name()); 1268825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1269825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (permuted_index != -1) { 1270bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Apply permutation before extracting dimension for pool_kernel 1271bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney output_tensor = ctx.TransposeTensor( 1272825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order); 1273825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1274825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 1275825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1276825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1277825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertPad(Converter& ctx, 1278bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def, 1279825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights> const& inputs, 
1280825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<TRT_TensorOrWeights>* outputs) { 1281825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (inputs.size() != 2 || !inputs.at(0).is_tensor() || 1282825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama !inputs.at(1).is_weights()) 1283825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1284825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Input expects tensor and weights, at" + node_def.name()); 1285825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1286bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Implement tensor binaryOp weight [channel wise] for now; 1287825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); 1288825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto dims = tensor->getDimensions(); 1289bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Restore implicit batch dimension 1290bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney int nb_dims = dims.nbDims + 1; 1291825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1292825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TRT_ShapedWeights pads = inputs.at(1).weights(); 1293825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1294825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TFAttrs attrs(node_def); 1295bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Padding type here is done through TF type 1296825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // so I can leverage their EnumToDataType for my cast 1297825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings"); 1298825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): handle data type conversion for TRT? 
1299825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1300bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) 1301825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1302825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Pad only supports explicit padding on 4 dimensional tensor, at " + 1303825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_def.name()); 1304825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1305825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Only expect to handle INT32 as attributes for now 1306825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (padding_type != tensorflow::DataType::DT_INT32) 1307825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 1308825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Tpaddings supports only DT_INT32"); 1309cfa374cefe132be886c26a374c51454177c68868gracehoney auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues())); 1310825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1311825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<int32_t> pad_index; 1312bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney for (int i = 0; i < nb_dims; i++) { 1313825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) 1314825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama pad_index.push_back(i); 1315825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1316825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1317bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // No padding at all, we should exit 1318825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_index.size() == 0) { 1319825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(inputs.at(0)); 1320825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 1321825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 
1322825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1323bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Only supports padding on less than 2 axis GIE-2579 1324825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_index.size() > 2) 1325825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1326825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Padding layer does not support padding on > 2"); 1327825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1328bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Padding on batch dimension is not supported 1329825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_index[0] == 0) 1330825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1331825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Padding layer does not support padding on batch dimension"); 1332825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1333bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Not doing the legit thing here. 
ignoring padding on dim 1 and 3; 1334825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): implement pad as uff parser 1335825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) 1336825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Unimplemented( 1337825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Padding layer does not support padding on dimension 1 and 3 yet"); 1338825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1339825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama bool legit_pad = true; 1340825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsHW pre_padding(0, 0); 1341825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsHW post_padding(0, 0); 1342825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1343825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<int32_t> permuted_pad_index(pad_index); 1344825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (pad_index[0] == 1) { 1345825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama legit_pad = false; 1346bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), 1347825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama {0, 3, 2, 1}); 1348825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama permuted_pad_index[0] = 3; 1349825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1350825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1351825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (size_t i = 0; i < pad_index.size(); i++) { 1352825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int index = pad_index[i]; 1353825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (permuted_pad_index[i] == 2) { 1354825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama pre_padding.h() = pad_data[index * 2]; 1355825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama post_padding.h() = pad_data[index * 2 + 1]; 1356825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } else 
if (permuted_pad_index[i] == 3) { 1357825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama pre_padding.w() = pad_data[index * 2]; 1358825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama post_padding.w() = pad_data[index * 2 + 1]; 1359825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1360825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1361825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1362825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( 1363825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding); 1364825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* output_tensor = layer->getOutput(0); 1365825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1366825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!legit_pad) 1367bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney output_tensor = ctx.TransposeTensor( 1368825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1}); 1369825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1370825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama outputs->push_back(TRT_TensorOrWeights(output_tensor)); 1371825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 1372825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1373825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1374825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamavoid Converter::register_op_converters() { 1375825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // vgg_16 slim implementation 1376cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Placeholder"] = ConvertPlaceholder; 1377cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Conv2D"] = ConvertConv2D; 1378cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Relu"] = ConvertActivation; 1379cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["MaxPool"] = ConvertPool; 
1380825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // This could be really handled as ConvertBinary 1381cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["BiasAdd"] = ConvertScale; 1382cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Const"] = ConvertConst; 1383cd63c718be123324b6c39e0f8fbe453319799746Jie // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg 1384825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(ben,jie): this is a temp hack. 1385cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Identity"] = ConvertIdentity; // Identity should be removed 1386cd63c718be123324b6c39e0f8fbe453319799746Jie // op_registry_["AvgPool"] = ConvertPool; 1387825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1388825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // resnet_50_v1 slim implementation 1389cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Add"] = ConvertBinary; 1390cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Mul"] = ConvertBinary; 1391cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Sub"] = ConvertBinary; 1392cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Rsqrt"] = ConvertUnary; 1393cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Mean"] = ConvertReduce; 1394cd63c718be123324b6c39e0f8fbe453319799746Jie op_registry_["Pad"] = ConvertPad; 1395825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(ben,jie): Add more ops 1396825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1397825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1398825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} // namespace 1399825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1400825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kamatensorflow::Status ConvertSubGraphToTensorRTNodeDef( 1401825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids, 1402825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const std::vector<std::pair<int, int>>& 
input_inds, 1403825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size, 1404f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama size_t max_workspace_size_bytes, 140568e17d497497119c24ad506dac4e34e127cf836cJie const tensorflow::grappler::GraphProperties& graph_properties, 1406825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::NodeDef* trt_node) { 1407825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Visit nodes in reverse topological order and construct the TRT network. 1408825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1409825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Toposort 1410825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<tensorflow::Node*> order_vec; 1411825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::GetPostOrder(graph, &order_vec); 1412825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Select just the subgraph 1413825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::list<tensorflow::Node*> order; 1414825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (tensorflow::Node* node : order_vec) { 1415825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (subgraph_node_ids.count(node->id())) { 1416bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // We want topological order to contstruct the 1417bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // network layer by layer 1418bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney order.push_front(node); 1419825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1420825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1421bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // Topological order is needed to build TRT network 1422825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1423825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::tensorrt::Logger trt_logger; 1424825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1425825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto trt_builder = 
infer_object(nvinfer1::createInferBuilder(trt_logger)); 1426825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!trt_builder) { 1427825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Internal( 1428bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Failed to create TensorRT builder object"); 1429825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1430825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1431825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto trt_network = infer_object(trt_builder->createNetwork()); 1432825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!trt_network) { 1433825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Internal( 1434bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Failed to create TensorRT network object"); 1435825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1436825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1437825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Build the network 1438825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama Converter converter(trt_network.get()); 1439825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1440bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney std::vector<string> input_names; 1441825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<tensorflow::DataType> input_dtypes; 1442825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (std::pair<int, int> const& input : input_inds) { 1443825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int node_id = input.first; 1444825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int output_idx = input.second; 1445825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::Node* node = graph.FindNodeId(node_id); 1446825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto node_name = node->name(); 1447bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney input_names.push_back(node_name); // Insert original node name without port 1448825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(jie): 
alternative :) 144968e17d497497119c24ad506dac4e34e127cf836cJie if (!graph_properties.HasOutputProperties(node_name)) 1450bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney return tensorflow::errors::Internal("Failed to find input node: " + 1451825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama node_name); 1452825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 145368e17d497497119c24ad506dac4e34e127cf836cJie auto op_info_vec = graph_properties.GetOutputProperties(node_name); 145468e17d497497119c24ad506dac4e34e127cf836cJie if (static_cast<int>(op_info_vec.size()) < output_idx) 1455825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::Internal( 1456bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Accessing output index of: " + std::to_string(output_idx) + 1457bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney ", at node: " + node_name + " with output entry from shape_map: " + 145868e17d497497119c24ad506dac4e34e127cf836cJie std::to_string(op_info_vec.size())); 1459825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 146068e17d497497119c24ad506dac4e34e127cf836cJie auto op_info = op_info_vec.at(output_idx); 1461825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 146268e17d497497119c24ad506dac4e34e127cf836cJie tensorflow::DataType tf_dtype = op_info.dtype(); 1463825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama input_dtypes.push_back(tf_dtype); 1464825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1465825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); 1466bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TF_CHECK_OK(ConvertDType(tf_dtype, &dtype)); 1467825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1468bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Accessing output index of: " << std::to_string(output_idx) 1469f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << ", at node: " << node_name 1470bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney << " with output entry from shape_map: " 
1471f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << std::to_string(op_info_vec.size()); 147268e17d497497119c24ad506dac4e34e127cf836cJie 1473825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(ben,jie): update TRT input format/dimension 1474825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DimsCHW input_dim_psuedo_chw; 1475825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; 1476825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 147768e17d497497119c24ad506dac4e34e127cf836cJie for (int i = 1; i < op_info.shape().dim_size(); i++) { 1478f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(2) << "dimension: " << i 1479f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama << " , size: " << op_info.shape().dim(i).size(); 148068e17d497497119c24ad506dac4e34e127cf836cJie input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size(); 1481825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1482825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1483825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(ben,jie): proper way to restore input tensor name? 
1484825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto input_tensor_name = node_name; 1485825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (output_idx != 0) 1486825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama input_tensor_name = node_name + ":" + std::to_string(output_idx); 1487825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1488825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* input_tensor = converter.network()->addInput( 1489825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); 1490825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1491825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!input_tensor) 1492825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1493825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Failed to create Input layer"); 1494bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Input tensor name :" << input_tensor_name; 1495825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1496825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) 1497825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::AlreadyExists( 1498bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney "Output tensor already exists for op: " + input_tensor_name); 1499825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1500825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1501bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Finished sorting"; 1502825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1503825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (const tensorflow::Node* node : order) { 1504bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney const tensorflow::NodeDef& node_def = node->def(); 1505bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); 
1506825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama TF_RETURN_IF_ERROR(converter.convert_node(node_def)); 1507825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1508825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1509bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Finished conversion"; 1510825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1511825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Gather output metadata 1512bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney std::vector<string> output_names; 1513825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<tensorflow::DataType> output_dtypes; 1514825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (std::pair<int, int> const& output : output_inds) { 1515825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int node_id = output.first; 1516825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int output_idx = output.second; 1517825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::Node* node = graph.FindNodeId(node_id); 1518bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney string op_name = node->name(); 1519bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney string tensor_name = op_name; 1520825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (output_idx != 0) 1521825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensor_name = tensor_name + ":" + std::to_string(output_idx); 1522bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Output tensor name: " << tensor_name; 1523825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama output_names.push_back(tensor_name); 1524825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto tensor_or_weights = converter.get_tensor(tensor_name); 1525825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!tensor_or_weights.is_tensor()) { 1526825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::InvalidArgument( 1527825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama "Output node is weights not tensor"); 
1528825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1529825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); 1530825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama if (!tensor) { 1531825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::errors::NotFound("Output tensor not found: " + 1532825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensor_name); 1533825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1534825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama converter.network()->markOutput(*tensor); 1535825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::DataType tf_dtype = node->output_type(output_idx); 1536825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama output_dtypes.push_back(tf_dtype); 1537825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; 1538bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); 1539825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensor->setType(trt_dtype); 1540825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1541825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1542bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(2) << "Finished output"; 1543cd63c718be123324b6c39e0f8fbe453319799746Jie // TODO(jie): static_id is not thread safe. 
1544599eadc299ae680bfb569ace4278b2eb262ecc44Sami Kama static int static_id = 0; 1545825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1546825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Build the engine 1547825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama trt_builder->setMaxBatchSize(max_batch_size); 1548f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes); 1549bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(0) << "Starting build engine " << static_id; 1550825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // TODO(ben,jie): half2 and int8 mode support 1551bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney string engine_plan_string; 1552825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama { 1553825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto trt_engine = 1554825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama infer_object(trt_builder->buildCudaEngine(*converter.network())); 1555bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(0) << "Built network"; 1556825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto engine_plan = infer_object(trt_engine->serialize()); 1557bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(0) << "Serialized engine"; 1558825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama const char* engine_plan_data = 1559825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama static_cast<const char*>(engine_plan->data()); 1560bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney engine_plan_string = 1561bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney string(engine_plan_data, engine_plan_data + engine_plan->size()); 1562825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1563825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1564bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(0) << "Finished engine"; 1565825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1566825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama // Build the TRT op 1567825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 
// TODO(sami,ben,jie): proper naming! 1568825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama tensorflow::NodeDefBuilder op_builder( 1569bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp"); 1570825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges; 1571825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama for (size_t i = 0; i < input_names.size(); ++i) { 1572825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama int output_idx = input_inds.at(i).second; 1573bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // We wired up the input here already, it is redundant to do it again in 1574bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney // ConvertSubGraphToTensorRT(convert_graph.cc) 1575e01844e65e0dbd2682a894946bec7f072d36fa27Jie auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( 1576e01844e65e0dbd2682a894946bec7f072d36fa27Jie input_names.at(i), output_idx, input_dtypes.at(i)); 1577825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama income_edges.push_back(incoming_edge); 1578825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama } 1579e01844e65e0dbd2682a894946bec7f072d36fa27Jie tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list( 1580e01844e65e0dbd2682a894946bec7f072d36fa27Jie income_edges); 1581825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama op_builder.Input(input_list); 1582825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1583bfe8b85cad3be1a82234500fce3064c98dd20d09gracehoney VLOG(0) << "Finished op preparation"; 1584825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1585825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama auto status = op_builder.Attr("serialized_engine", engine_plan_string) 1586825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama .Attr("input_nodes", input_names) 1587825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama .Attr("output_nodes", output_names) 1588825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 
.Attr("OutT", output_dtypes) 1589825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama .Finalize(trt_node); 1590825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1591f8b1986d67b1bcc352acb7644b642faf46ca79cbSami Kama VLOG(0) << status.ToString() << " finished op building"; 1592825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1593825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama return tensorflow::Status::OK(); 1594825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} 1595825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama 1596825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} // namespace convert 1597825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0eSami Kama} // namespace tensorrt 15986908cc233c679b8fe61d99a30d3828362caf47beSami Kama} // namespace tensorflow 1599ae740a67bdc01b991ead6ac047c774bff4d7bc8fJie 1600e01844e65e0dbd2682a894946bec7f072d36fa27Jie#endif // GOOGLE_TENSORRT 1601e01844e65e0dbd2682a894946bec7f072d36fa27Jie#endif // GOOGLE_CUDA 1602