1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/platform/logging.h"
17#include "tensorflow/core/platform/test.h"
18
19#if GOOGLE_CUDA
20#if GOOGLE_TENSORRT
21#include "cuda/include/cuda.h"
22#include "cuda/include/cuda_runtime_api.h"
23#include "tensorrt/include/NvInfer.h"
24
25namespace tensorflow {
26namespace {
27
28class Logger : public nvinfer1::ILogger {
29 public:
30  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
31    switch (severity) {
32      case Severity::kINFO:
33        LOG(INFO) << msg;
34        break;
35      case Severity::kWARNING:
36        LOG(WARNING) << msg;
37        break;
38      case Severity::kINTERNAL_ERROR:
39      case Severity::kERROR:
40        LOG(ERROR) << msg;
41        break;
42      default:
43        break;
44    }
45  }
46};
47
48class ScopedWeights {
49 public:
50  ScopedWeights(float value) : value_(value) {
51    w.type = nvinfer1::DataType::kFLOAT;
52    w.values = &value_;
53    w.count = 1;
54  }
55  const nvinfer1::Weights& get() { return w; }
56
57 private:
58  float value_;
59  nvinfer1::Weights w;
60};
61
// Names of the network's input and output tensors. They are used both when the
// network is defined and when binding indices are looked up on the engine.
// Declared as constexpr arrays so that neither the characters nor the binding
// itself can be modified at runtime.
constexpr char kInputTensor[] = "input";
constexpr char kOutputTensor[] = "output";
64
65// Creates a network to compute y=2x+3.
66nvinfer1::IHostMemory* CreateNetwork() {
67  Logger logger;
68  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
69  ScopedWeights weights(2.0);
70  ScopedWeights bias(3.0);
71
72  nvinfer1::INetworkDefinition* network = builder->createNetwork();
73  // Add the input.
74  auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
75                                 nvinfer1::DimsCHW{1, 1, 1});
76  EXPECT_NE(input, nullptr);
77  // Add the hidden layer.
78  auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
79  EXPECT_NE(layer, nullptr);
80  // Mark the output.
81  auto output = layer->getOutput(0);
82  output->setName(kOutputTensor);
83  network->markOutput(*output);
84  // Build the engine
85  builder->setMaxBatchSize(1);
86  builder->setMaxWorkspaceSize(1 << 10);
87  auto engine = builder->buildCudaEngine(*network);
88  EXPECT_NE(engine, nullptr);
89  // Serialize the engine to create a model, then close everything.
90  nvinfer1::IHostMemory* model = engine->serialize();
91  network->destroy();
92  engine->destroy();
93  builder->destroy();
94  return model;
95}
96
97// Executes the network.
98void Execute(nvinfer1::IExecutionContext& context, const float* input,
99             float* output) {
100  const nvinfer1::ICudaEngine& engine = context.getEngine();
101
102  // We have two bindings: input and output.
103  ASSERT_EQ(engine.getNbBindings(), 2);
104  const int input_index = engine.getBindingIndex(kInputTensor);
105  const int output_index = engine.getBindingIndex(kOutputTensor);
106
107  // Create GPU buffers and a stream
108  void* buffers[2];
109  ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
110  ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
111  cudaStream_t stream;
112  ASSERT_EQ(0, cudaStreamCreate(&stream));
113
114  // Copy the input to the GPU, execute the network, and copy the output back.
115  //
116  // Note that since the host buffer was not created as pinned memory, these
117  // async copies are turned into sync copies. So the following synchronization
118  // could be removed.
119  ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
120                               cudaMemcpyHostToDevice, stream));
121  context.enqueue(1, buffers, stream, nullptr);
122  ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
123                               cudaMemcpyDeviceToHost, stream));
124  cudaStreamSynchronize(stream);
125
126  // Release the stream and the buffers
127  cudaStreamDestroy(stream);
128  ASSERT_EQ(0, cudaFree(buffers[input_index]));
129  ASSERT_EQ(0, cudaFree(buffers[output_index]));
130}
131
132TEST(TensorrtTest, BasicFunctions) {
133  // Create the network model.
134  nvinfer1::IHostMemory* model = CreateNetwork();
135  // Use the model to create an engine and then an execution context.
136  Logger logger;
137  nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
138  nvinfer1::ICudaEngine* engine =
139      runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
140  model->destroy();
141  nvinfer1::IExecutionContext* context = engine->createExecutionContext();
142
143  // Execute the network.
144  float input = 1234;
145  float output;
146  Execute(*context, &input, &output);
147  EXPECT_EQ(output, input * 2 + 3);
148
149  // Destroy the engine.
150  context->destroy();
151  engine->destroy();
152  runtime->destroy();
153}
154
155}  // namespace
156}  // namespace tensorflow
157
158#endif  // GOOGLE_TENSORRT
159#endif  // GOOGLE_CUDA
160