convert_graph.cc revision 6908cc233c679b8fe61d99a30d3828362caf47be
1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
17
18#include <list>
19#include <set>
20#include <sstream>
21#include <string>
22#include <unordered_map>
23#include <unordered_set>
24#include <vector>
25#include <map>
26#include <utility>
27
28#include "NvInfer.h"
29
30#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
31#include "tensorflow/contrib/tensorrt/segment/segment.h"
32#include "tensorflow/core/framework/graph.pb.h"
33#include "tensorflow/core/framework/node_def.pb.h"
34#include "tensorflow/core/graph/algorithm.h"
35#include "tensorflow/core/graph/graph.h"
36#include "tensorflow/core/graph/graph_constructor.h"
37#include "tensorflow/core/lib/core/errors.h"
38#include "tensorflow/core/lib/core/status.h"
39#include "tensorflow/core/platform/logging.h"
40
41#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
42#include "tensorflow/core/grappler/optimizers/constant_folding.h"
43#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
44#include "tensorflow/core/grappler/devices.h"
45#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
46#include "tensorflow/core/protobuf/device_properties.pb.h"
47#include "tensorflow/core/grappler/grappler_item.h"
48#include "tensorflow/core/grappler/utils.h"
49
50#include "tensorflow/core/grappler/costs/graph_properties.h"
51
52//------------------------------------------------------------------------------
53namespace tensorflow {
54namespace tensorrt {
55namespace convert {
56namespace {
57
58static std::unordered_set<std::string> output_nodes;
59bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
60// LINT.IfChange
61  static const std::set<std::string> candidate_ops = {
62      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
63      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
64                                                       // TODO(ben,jie): ...
65  };
66// LINT.ThenChange(
67//    https://www.tensorflow.org/code/tensorflow/contrib/tensorrt/convert/convert_nodes.h)
68
69  if (output_nodes.count(node_def.name())) return false;
70  return candidate_ops.count(node_def.op());
71}
72
73void GetSubGraphIncomingEdges(const tensorflow::Graph & graph,
74                              const std::set<int> &subgraph_node_ids,
75                              tensorflow::EdgeSet* incoming_edges) {
76  for (int node_id : subgraph_node_ids) {
77    tensorflow::Node const* node = graph.FindNodeId(node_id);
78    LOG(DEBUG) << node->name() << " has incoming edges: ";
79    for (tensorflow::Edge const* edge : node->in_edges()) {
80      if (!subgraph_node_ids.count(edge->src()->id()) &&
81          !edge->src()->IsSource()) {
82        LOG(DEBUG) << edge->src()->name() << ", ";
83        incoming_edges->insert(edge);
84      }
85    }
86  }
87}
88
89void GetSubGraphOutgoingEdges(const tensorflow::Graph &graph,
90                              const std::set<int> &subgraph_node_ids,
91                              tensorflow::EdgeSet* outgoing_edges) {
92  for (int node_id : subgraph_node_ids) {
93    tensorflow::Node const* node = graph.FindNodeId(node_id);
94    LOG(DEBUG) << node->name() << " has outgoing edges: ";
95    for (tensorflow::Edge const* edge : node->out_edges()) {
96      if (!subgraph_node_ids.count(edge->dst()->id()) &&
97          !edge->dst()->IsSink()) {
98        outgoing_edges->insert(edge);
99      }
100    }
101  }
102}
103
// Splits a tensor name of the form "node:index" into its node name and output
// index.
//
// Only a non-empty, all-digit suffix after the last ':' is treated as an
// index. Otherwise the whole string is returned as the node name together with
// `default_idx`.
//
// std::stoi still throws std::out_of_range if the digit suffix does not fit in
// an int.
std::pair<std::string, int> ParseTensorName(std::string name,
                                            int default_idx = 0) {
  int idx = default_idx;
  const size_t sep = name.find_last_of(':');
  if (sep != std::string::npos) {
    const std::string suffix = name.substr(sep + 1);
    const bool is_index =
        !suffix.empty() &&
        suffix.find_first_not_of("0123456789") == std::string::npos;
    if (is_index) {
      // Read the index BEFORE truncating `name`: once the name has been cut
      // to `sep` characters, position `sep + 1` is past its end.
      idx = std::stoi(suffix);
      name.erase(sep);
    }
  }
  return std::make_pair(std::move(name), idx);
}
114
115std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
116    const std::vector<std::string>& tensor_names) {
117  std::unordered_map<std::string, std::vector<int>> result;
118  for (std::string const& tensor_name : tensor_names) {
119    std::string node_name;
120    int index;
121    std::tie(node_name, index) = ParseTensorName(tensor_name);
122    result[node_name].push_back(index);
123  }
124  return result;
125}
126
127tensorflow::Status ConvertSubGraphToTensorRT(
128    tensorflow::Graph& graph, const std::vector<std::string>& output_names,
129    const std::set<int>& subgraph_node_ids,
130    size_t max_batch_size, // max batch size that engine will be created for
131    // max amount of memory that engine will be allowed to consume, in bytes
132    size_t max_workspace_size,
133    const tensorflow::grappler::GraphProperties& graph_properties) {
134  tensorflow::EdgeSet subgraph_incoming_edges;
135  GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);
136
137  std::vector<std::pair<int, int>> subgraph_inputs;
138
139
140  // Collect inputs by looking for incoming edges
141  for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
142    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
143  }
144  std::set<std::pair<int, int>> subgraph_outputs_set;
145  // Collect outputs referenced from output_names
146  auto output_name_to_index_map = BuildTensorNameMap(output_names);
147  for (int node_id : subgraph_node_ids) {
148    tensorflow::Node* node = graph.FindNodeId(node_id);
149    if (output_name_to_index_map.count(node->name())) {
150      for (int index : output_name_to_index_map.at(node->name())) {
151        subgraph_outputs_set.insert({node_id, index});
152      }
153    }
154  }
155  // Collect outputs referenced from outgoing edges
156  tensorflow::EdgeSet subgraph_outgoing_edges;
157  GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
158  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
159    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
160  }
161  // Impose an ordering on the outputs
162  std::vector<std::pair<int, int>> subgraph_outputs(
163      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
164  // Build TensorRT node and add it to the graph
165  tensorflow::NodeDef trt_node_def;
166  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
167      graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
168      max_batch_size, max_workspace_size, graph_properties, &trt_node_def));
169  tensorflow::Status status;
170  tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
171
172  TF_RETURN_IF_ERROR(status);
173
174  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
175  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
176  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
177    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
178  }
179  TF_RETURN_IF_ERROR(status);
180  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
181    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
182    int new_src_output = subgraph_edge_to_output_map.at(old_src);
183    graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
184  }
185  // Remove the original subgraph
186  for (int node_id : subgraph_node_ids) {
187    tensorflow::Node* node = graph.FindNodeId(node_id);
188    // Don't remove the input placeholders
189    if (node->type_string() == "Placeholder") {
190      continue;
191    }
192    graph.RemoveNode(node);
193  }
194  return tensorflow::Status::OK();
195}
196
197tensorflow::Status BuildNodeMap(
198    const tensorflow::Graph& graph,
199    std::unordered_map<std::string, tensorflow::Node*>* node_map) {
200  for (auto* node : graph.op_nodes()) {
201    if (!node_map->insert({node->name(), node}).second) {
202      return tensorflow::errors::AlreadyExists(
203          "Node name is not unique in graph: " + node->name());
204    }
205  }
206  return tensorflow::Status::OK();
207}
208
209}  // namespace
210
211tensorflow::Status ConvertGraphDefToTensorRT(
212    const tensorflow::GraphDef& graph_def,
213    const std::vector<std::string>& output_names, size_t max_batch_size,
214    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
215
216  // optimization pass
217  tensorflow::grappler::GrapplerItem item;
218  item.fetch = output_names;
219  tensorflow::GraphDef gdef;
220
221  // layout optimization
222  item.graph = graph_def;
223  tensorflow::grappler::LayoutOptimizer optimizer;
224  tensorflow::grappler::Cluster* gCluster;
225
226  // virtual cluster
227  tensorflow::DeviceProperties device_properties;
228  device_properties.set_type("GPU");
229  device_properties.mutable_environment()->insert({"architecture", "6"});
230  gCluster =
231    new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
232
233  // single machine
234  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
235  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
236  LOG(DEBUG) << "cpu_cores: " << num_cpu_cores;
237  LOG(DEBUG) << "gpus: " << num_gpus;
238
239  tensorflow::Status status = optimizer.Optimize(gCluster, item, &gdef);
240
241  if (status !=tensorflow::Status::OK())
242    return status;
243
244  // constant folding
245  item.graph = gdef;
246  tensorflow::grappler::ConstantFolding fold(nullptr);
247  status = fold.Optimize(nullptr, item, &gdef);
248  if (status !=tensorflow::Status::OK())
249    return status;
250
251  // AJ refactoring shape inference through grappler/GraphProperties.
252  tensorflow::grappler::GraphProperties static_graph_properties(item);
253  static_graph_properties.InferStatically(false);
254
255  // Build full graph
256  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
257                                             gdef.library());
258  tensorflow::Graph graph(flib);
259  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
260      tensorflow::GraphConstructorOptions(), gdef, &graph));
261
262  // Segment the graph into subgraphs that can be converted to TensorRT
263  tensorflow::tensorrt::segment::SegmentOptions segment_options;
264  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
265  for (auto node : output_names) output_nodes.insert(node);
266
267  // TODO(sami): this should be passed as a knob!!!!
268  segment_options.minimum_segment_size = 2;
269  tensorflow::tensorrt::segment::SegmentNodesVector segments;
270  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
271      gdef, IsTensorRTCandidate, segment_options, &segments));
272  if (segments.size() > 1) {
273    LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
274  }
275  std::unordered_map<std::string, tensorflow::Node*> node_map;
276  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
277  for (const std::set<std::string> &subgraph_node_names : segments) {
278    std::set<int> subgraph_node_ids;
279    for (const std::string &node_name : subgraph_node_names) {
280      subgraph_node_ids.insert(node_map.at(node_name)->id());
281    }
282    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
283        graph, output_names, subgraph_node_ids, max_batch_size,
284        max_workspace_size, static_graph_properties));
285  }
286  graph.ToGraphDef(new_graph_def);
287  return tensorflow::Status::OK();
288}
289
290}  // namespace convert
291}  // namespace tensorrt
292}  // namespace tensorflow
293