convert_graph.cc revision 599eadc299ae680bfb569ace4278b2eb262ecc44
1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
17
18#include <list>
19#include <map>
20#include <set>
21#include <sstream>
22#include <string>
23#include <unordered_map>
24#include <unordered_set>
25#include <utility>
26#include <vector>
27
28#include "tensorflow/core/framework/graph.pb.h"
29#include "tensorflow/core/framework/node_def.pb.h"
30#include "tensorflow/core/graph/algorithm.h"
31#include "tensorflow/core/graph/graph.h"
32#include "tensorflow/core/graph/graph_constructor.h"
33#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
34#include "tensorflow/core/grappler/costs/graph_properties.h"
35#include "tensorflow/core/grappler/devices.h"
36#include "tensorflow/core/grappler/grappler_item.h"
37#include "tensorflow/core/grappler/optimizers/constant_folding.h"
38#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
39#include "tensorflow/core/grappler/utils.h"
40#include "tensorflow/core/lib/core/errors.h"
41#include "tensorflow/core/lib/core/status.h"
42#include "tensorflow/core/platform/logging.h"
43#include "tensorflow/core/protobuf/device_properties.pb.h"
44
45#if GOOGLE_CUDA
46#if GOOGLE_TENSORRT
47#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
48#include "tensorflow/contrib/tensorrt/segment/segment.h"
49#include "tensorrt/include/NvInfer.h"
50
51//------------------------------------------------------------------------------
52namespace tensorflow {
53namespace tensorrt {
54namespace convert {
55namespace {
56
57static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
58  // LINT.IfChange
59  // TODO(jie): Segmentation shouldn't associated with op name.
60  //            Split it into a registration for each kernel.
61  static const std::set<std::string> candidate_ops = {
62      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
63      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
64  };
65  // LINT.ThenChange(
66  //    https://www.tensorflow.org/code/tensorflow/contrib/tensorrt/convert/convert_nodes.h)
67  return candidate_ops.count(node_def.op());
68}
69
70void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
71                              const std::set<int>& subgraph_node_ids,
72                              tensorflow::EdgeSet* incoming_edges) {
73  for (int node_id : subgraph_node_ids) {
74    const tensorflow::Node* node = graph.FindNodeId(node_id);
75    for (const tensorflow::Edge* edge : node->in_edges()) {
76      if (!subgraph_node_ids.count(edge->src()->id()) &&
77          !edge->src()->IsSource()) {
78        incoming_edges->insert(edge);
79      }
80    }
81  }
82}
83
84void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
85                              const std::set<int>& subgraph_node_ids,
86                              tensorflow::EdgeSet* outgoing_edges) {
87  for (int node_id : subgraph_node_ids) {
88    const tensorflow::Node* node = graph.FindNodeId(node_id);
89    for (const tensorflow::Edge* edge : node->out_edges()) {
90      if (!subgraph_node_ids.count(edge->dst()->id()) &&
91          !edge->dst()->IsSink()) {
92        outgoing_edges->insert(edge);
93      }
94    }
95  }
96}
97
// Splits a tensor name of the form "node_name:index" into its node name and
// output index.
//
// The split happens at the LAST ':' in `name`. If `name` contains no ':', the
// whole string is returned as the node name together with `default_idx`.
// The text after the ':' is converted with std::stoi, so a non-numeric suffix
// is reported the way std::stoi reports it.
std::pair<std::string, int> ParseTensorName(std::string name,
                                            int default_idx = 0) {
  int idx = default_idx;
  const size_t sep = name.find_last_of(':');
  if (sep != std::string::npos) {
    // The index must be read BEFORE `name` is shortened: once the string is
    // cut to `sep` characters, position `sep + 1` lies past its end and
    // substr() would throw std::out_of_range.
    idx = std::stoi(name.substr(sep + 1));
    name.erase(sep);
  }
  return std::make_pair(std::move(name), idx);
}
108
109std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
110    const std::vector<std::string>& tensor_names) {
111  std::unordered_map<std::string, std::vector<int>> result;
112  for (std::string const& tensor_name : tensor_names) {
113    std::string node_name;
114    int index;
115    std::tie(node_name, index) = ParseTensorName(tensor_name);
116    result[node_name].push_back(index);
117  }
118  return result;
119}
120
121tensorflow::Status ConvertSubGraphToTensorRT(
122    const std::vector<std::string>& output_names,
123    const std::set<int>& subgraph_node_ids,
124    size_t max_batch_size,  // max batch size that engine will be created for
125    // max amount of memory that engine will be allowed to consume, in bytes
126    size_t max_workspace_size,
127    const tensorflow::grappler::GraphProperties& graph_properties,
128    tensorflow::Graph* graph) {
129  tensorflow::EdgeSet subgraph_incoming_edges;
130  GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
131
132  std::vector<std::pair<int, int>> subgraph_inputs;
133
134  // Collect inputs by looking for incoming edges
135  for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
136    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
137  }
138  std::set<std::pair<int, int>> subgraph_outputs_set;
139  // Collect outputs referenced from output_names
140  auto output_name_to_index_map = BuildTensorNameMap(output_names);
141  for (int node_id : subgraph_node_ids) {
142    tensorflow::Node* node = graph->FindNodeId(node_id);
143    if (output_name_to_index_map.count(node->name())) {
144      for (int index : output_name_to_index_map.at(node->name())) {
145        subgraph_outputs_set.insert({node_id, index});
146      }
147    }
148  }
149  // Collect outputs referenced from outgoing edges
150  tensorflow::EdgeSet subgraph_outgoing_edges;
151  GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
152  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
153    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
154  }
155  // Impose an ordering on the outputs
156  std::vector<std::pair<int, int>> subgraph_outputs(
157      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
158  // Build TensorRT node and add it to the graph
159  tensorflow::NodeDef trt_node_def;
160  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
161      *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
162      max_batch_size, max_workspace_size, graph_properties, &trt_node_def));
163  tensorflow::Status status;
164  tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
165
166  TF_RETURN_IF_ERROR(status);
167
168  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
169  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
170  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
171    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
172  }
173  TF_RETURN_IF_ERROR(status);
174  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
175    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
176    int new_src_output = subgraph_edge_to_output_map.at(old_src);
177    graph->UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
178  }
179  // Remove the original subgraph
180  for (int node_id : subgraph_node_ids) {
181    tensorflow::Node* node = graph->FindNodeId(node_id);
182    // Don't remove the input placeholders
183    if (node->type_string() == "Placeholder") {
184      continue;
185    }
186    graph->RemoveNode(node);
187  }
188  return tensorflow::Status::OK();
189}
190
191tensorflow::Status BuildNodeMap(
192    const tensorflow::Graph& graph,
193    std::unordered_map<std::string, tensorflow::Node*>* node_map) {
194  for (auto* node : graph.op_nodes()) {
195    if (!node_map->insert({node->name(), node}).second) {
196      return tensorflow::errors::AlreadyExists(
197          "Node name is not unique in graph: " + node->name());
198    }
199  }
200  return tensorflow::Status::OK();
201}
202
203}  // namespace
204
205tensorflow::Status ConvertGraphDefToTensorRT(
206    const tensorflow::GraphDef& graph_def,
207    const std::vector<std::string>& output_names, size_t max_batch_size,
208    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
209  // optimization pass
210  tensorflow::grappler::GrapplerItem item;
211  item.fetch = output_names;
212  tensorflow::GraphDef gdef;
213
214  // layout optimization
215  item.graph = graph_def;
216  tensorflow::grappler::LayoutOptimizer optimizer;
217  tensorflow::grappler::Cluster* cluster;
218
219  // virtual cluster
220  tensorflow::DeviceProperties device_properties;
221  device_properties.set_type("GPU");
222  device_properties.mutable_environment()->insert({"architecture", "6"});
223  cluster =
224      new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
225
226  tensorflow::Status status = optimizer.Optimize(cluster, item, &gdef);
227
228  if (status != tensorflow::Status::OK()) return status;
229
230  // constant folding
231  item.graph = gdef;
232  tensorflow::grappler::ConstantFolding fold(nullptr);
233  status = fold.Optimize(nullptr, item, &gdef);
234  if (status != tensorflow::Status::OK()) return status;
235
236  // AJ refactoring shape inference through grappler/GraphProperties.
237  tensorflow::grappler::GraphProperties static_graph_properties(item);
238  static_graph_properties.InferStatically(false);
239
240  // Build full graph
241  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
242                                             gdef.library());
243  tensorflow::Graph graph(flib);
244  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
245      tensorflow::GraphConstructorOptions(), gdef, &graph));
246
247  // Segment the graph into subgraphs that can be converted to TensorRT
248  tensorflow::tensorrt::segment::SegmentOptions segment_options;
249
250  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
251  for (auto node : output_names) {
252    segment_options.exclude_node_list.insert(node);
253  }
254
255  // TODO(sami): this should be passed as a knob!!!!
256  segment_options.minimum_segment_size = 2;
257  tensorflow::tensorrt::segment::SegmentNodesVector segments;
258  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
259      gdef, IsTensorRTCandidate, segment_options, &segments));
260  if (segments.size() > 1) {
261    VLOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
262  }
263  std::unordered_map<std::string, tensorflow::Node*> node_map;
264  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
265  for (const std::set<std::string>& subgraph_node_names : segments) {
266    std::set<int> subgraph_node_ids;
267    for (const std::string& node_name : subgraph_node_names) {
268      subgraph_node_ids.insert(node_map.at(node_name)->id());
269    }
270    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
271        output_names, subgraph_node_ids, max_batch_size, max_workspace_size,
272        static_graph_properties, &graph));
273  }
274  graph.ToGraphDef(new_graph_def);
275  return tensorflow::Status::OK();
276}
277
278}  // namespace convert
279}  // namespace tensorrt
280}  // namespace tensorflow
281
282#endif  // GOOGLE_TENSORRT
283#endif  // GOOGLE_CUDA
284