1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
17
18#include <map>
19#include <set>
20#include <unordered_map>
21#include <utility>
22#include <vector>
23
24#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
25#include "tensorflow/contrib/tensorrt/segment/segment.h"
26#include "tensorflow/core/graph/algorithm.h"
27#include "tensorflow/core/graph/graph.h"
28#include "tensorflow/core/graph/graph_constructor.h"
29#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30#include "tensorflow/core/grappler/costs/graph_properties.h"
31#include "tensorflow/core/grappler/devices.h"
32#include "tensorflow/core/grappler/grappler_item.h"
33#include "tensorflow/core/grappler/optimizers/constant_folding.h"
34#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
35#include "tensorflow/core/grappler/utils.h"
36#include "tensorflow/core/lib/core/errors.h"
37#include "tensorflow/core/lib/core/status.h"
38#include "tensorflow/core/platform/logging.h"
39#include "tensorflow/core/platform/types.h"
40#include "tensorflow/core/protobuf/device_properties.pb.h"
41
42#if GOOGLE_CUDA
43#if GOOGLE_TENSORRT
44#include "tensorrt/include/NvInfer.h"
45
46namespace tensorflow {
47namespace tensorrt {
48namespace convert {
49namespace {
50
51static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
52  // LINT.IfChange
53  // TODO(jie): Segmentation shouldn't associated with op name.
54  //            Split it into a registration for each kernel.
55  static const std::set<string> candidate_ops = {
56      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
57      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
58  };
59  // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
60  return candidate_ops.count(node_def.op());
61}
62
63void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
64                              const std::set<int>& subgraph_node_ids,
65                              tensorflow::EdgeSet* incoming_edges) {
66  for (int node_id : subgraph_node_ids) {
67    const tensorflow::Node* node = graph.FindNodeId(node_id);
68    for (const tensorflow::Edge* edge : node->in_edges()) {
69      if (!subgraph_node_ids.count(edge->src()->id()) &&
70          !edge->src()->IsSource()) {
71        incoming_edges->insert(edge);
72      }
73    }
74  }
75}
76
77void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
78                              const std::set<int>& subgraph_node_ids,
79                              tensorflow::EdgeSet* outgoing_edges) {
80  for (int node_id : subgraph_node_ids) {
81    const tensorflow::Node* node = graph.FindNodeId(node_id);
82    for (const tensorflow::Edge* edge : node->out_edges()) {
83      if (!subgraph_node_ids.count(edge->dst()->id()) &&
84          !edge->dst()->IsSink()) {
85        outgoing_edges->insert(edge);
86      }
87    }
88  }
89}
90
// Splits a tensor name of the form "<node>:<index>" into its node name and
// output index.
//
// The text after the last ':' is accepted as the index only when it is a
// non-empty run of at most 9 decimal digits (which always fits in an int).
// In every other case -- no ':' at all, an empty suffix, a non-digit
// character, or too many digits -- `name` is returned unchanged together
// with `default_idx`. No exception is thrown for any input.
//
// Returns: {node name, output index}.
std::pair<string, int> ParseTensorName(string name, int default_idx = 0) {
  const size_t sep = name.find_last_of(':');
  if (sep == string::npos) {
    return std::make_pair(name, default_idx);
  }

  // Number of characters following the separator.
  const size_t num_digits = name.size() - sep - 1;
  if (num_digits == 0 || num_digits > 9) {
    return std::make_pair(name, default_idx);
  }

  // Read the index BEFORE truncating `name`; once the name has been cut at
  // `sep` the suffix no longer exists and cannot be parsed.
  int idx = 0;
  for (size_t i = sep + 1; i < name.size(); ++i) {
    const char c = name[i];
    if (c < '0' || c > '9') {
      return std::make_pair(name, default_idx);
    }
    idx = idx * 10 + (c - '0');
  }

  name.resize(sep);
  return std::make_pair(name, idx);
}
100
101std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
102    const std::vector<string>& tensor_names) {
103  std::unordered_map<string, std::vector<int>> result;
104  for (string const& tensor_name : tensor_names) {
105    string node_name;
106    int index;
107    std::tie(node_name, index) = ParseTensorName(tensor_name);
108    result[node_name].push_back(index);
109  }
110  return result;
111}
112
113tensorflow::Status ConvertSubGraphToTensorRT(
114    const std::vector<string>& output_names,
115    const std::set<int>& subgraph_node_ids,
116    size_t max_batch_size,  // Max batch size that engine will be created for
117    // Max amount of memory that engine will be allowed to consume, in bytes
118    size_t max_workspace_size_bytes,
119    const tensorflow::grappler::GraphProperties& graph_properties,
120    tensorflow::Graph* graph) {
121  tensorflow::EdgeSet subgraph_incoming_edges;
122  GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
123
124  std::vector<std::pair<int, int>> subgraph_inputs;
125
126  // Collect inputs by looking for incoming edges
127  for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
128    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
129  }
130  std::set<std::pair<int, int>> subgraph_outputs_set;
131  // Collect outputs referenced from output_names
132  auto output_name_to_index_map = BuildTensorNameMap(output_names);
133  for (int node_id : subgraph_node_ids) {
134    tensorflow::Node* node = graph->FindNodeId(node_id);
135    if (output_name_to_index_map.count(node->name())) {
136      for (int index : output_name_to_index_map.at(node->name())) {
137        subgraph_outputs_set.insert({node_id, index});
138      }
139    }
140  }
141  // Collect outputs referenced from outgoing edges
142  tensorflow::EdgeSet subgraph_outgoing_edges;
143  GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
144  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
145    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
146  }
147  // Impose an ordering on the outputs
148  std::vector<std::pair<int, int>> subgraph_outputs(
149      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
150  // Build TensorRT node and add it to the graph
151  tensorflow::NodeDef trt_node_def;
152  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
153      *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
154      max_batch_size, max_workspace_size_bytes, graph_properties,
155      &trt_node_def));
156  tensorflow::Status status;
157  tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
158  TF_RETURN_IF_ERROR(status);
159
160  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
161  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
162  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
163    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
164  }
165  TF_RETURN_IF_ERROR(status);
166  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
167    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
168    int new_src_output = subgraph_edge_to_output_map.at(old_src);
169    TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
170                                         edge->dst_input()));
171  }
172  // Remove the original subgraph
173  for (int node_id : subgraph_node_ids) {
174    tensorflow::Node* node = graph->FindNodeId(node_id);
175    // Don't remove the input placeholders
176    if (node->type_string() == "Placeholder") {
177      continue;
178    }
179    graph->RemoveNode(node);
180  }
181  return tensorflow::Status::OK();
182}
183
184tensorflow::Status BuildNodeMap(
185    const tensorflow::Graph& graph,
186    std::unordered_map<string, tensorflow::Node*>* node_map) {
187  for (auto* node : graph.op_nodes()) {
188    if (!node_map->insert({node->name(), node}).second) {
189      return tensorflow::errors::AlreadyExists(
190          "Node name is not unique in graph: " + node->name());
191    }
192  }
193  return tensorflow::Status::OK();
194}
195
196}  // namespace
197
198tensorflow::Status ConvertGraphDefToTensorRT(
199    const tensorflow::GraphDef& graph_def,
200    const std::vector<string>& output_names, size_t max_batch_size,
201    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
202  // Optimization pass
203  tensorflow::grappler::GrapplerItem item;
204  item.fetch = output_names;
205  tensorflow::GraphDef gdef;
206
207  // Layout optimization
208  item.graph = graph_def;
209  tensorflow::grappler::LayoutOptimizer optimizer;
210  tensorflow::grappler::Cluster* cluster;
211
212  // Virtual cluster
213  tensorflow::DeviceProperties device_properties;
214  device_properties.set_type("GPU");
215  device_properties.mutable_environment()->insert({"architecture", "6"});
216  cluster =
217      new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
218
219  TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
220
221  // Constant folding
222  item.graph = gdef;
223  tensorflow::grappler::ConstantFolding fold(nullptr);
224  TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
225
226  // AJ refactoring shape inference through grappler/GraphProperties.
227  tensorflow::grappler::GraphProperties static_graph_properties(item);
228  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
229
230  // Build full graph
231  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
232                                             gdef.library());
233  tensorflow::Graph graph(flib);
234  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
235      tensorflow::GraphConstructorOptions(), gdef, &graph));
236
237  // Segment the graph into subgraphs that can be converted to TensorRT
238  tensorflow::tensorrt::segment::SegmentOptions segment_options;
239
240  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
241  for (auto node : output_names) {
242    segment_options.exclude_node_list.insert(node);
243  }
244
245  // TODO(sami): this should be passed as a knob!!!!
246  segment_options.minimum_segment_size = 2;
247  tensorflow::tensorrt::segment::SegmentNodesVector segments;
248  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
249      gdef, IsTensorRTCandidate, segment_options, &segments));
250  if (segments.size() > 1) {
251    VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
252  }
253  std::unordered_map<string, tensorflow::Node*> node_map;
254  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
255  for (const std::set<string>& subgraph_node_names : segments) {
256    std::set<int> subgraph_node_ids;
257    for (const string& node_name : subgraph_node_names) {
258      subgraph_node_ids.insert(node_map.at(node_name)->id());
259    }
260    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
261        output_names, subgraph_node_ids, max_batch_size,
262        max_workspace_size_bytes, static_graph_properties, &graph));
263  }
264  graph.ToGraphDef(new_graph_def);
265  return tensorflow::Status::OK();
266}
267
268}  // namespace convert
269}  // namespace tensorrt
270}  // namespace tensorflow
271
272#endif  // GOOGLE_TENSORRT
273#endif  // GOOGLE_CUDA
274