convert_graph.cc revision 599eadc299ae680bfb569ace4278b2eb262ecc44
1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" 17 18#include <list> 19#include <map> 20#include <set> 21#include <sstream> 22#include <string> 23#include <unordered_map> 24#include <unordered_set> 25#include <utility> 26#include <vector> 27 28#include "tensorflow/core/framework/graph.pb.h" 29#include "tensorflow/core/framework/node_def.pb.h" 30#include "tensorflow/core/graph/algorithm.h" 31#include "tensorflow/core/graph/graph.h" 32#include "tensorflow/core/graph/graph_constructor.h" 33#include "tensorflow/core/grappler/clusters/virtual_cluster.h" 34#include "tensorflow/core/grappler/costs/graph_properties.h" 35#include "tensorflow/core/grappler/devices.h" 36#include "tensorflow/core/grappler/grappler_item.h" 37#include "tensorflow/core/grappler/optimizers/constant_folding.h" 38#include "tensorflow/core/grappler/optimizers/layout_optimizer.h" 39#include "tensorflow/core/grappler/utils.h" 40#include "tensorflow/core/lib/core/errors.h" 41#include "tensorflow/core/lib/core/status.h" 42#include "tensorflow/core/platform/logging.h" 43#include "tensorflow/core/protobuf/device_properties.pb.h" 44 45#if GOOGLE_CUDA 46#if GOOGLE_TENSORRT 47#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" 48#include "tensorflow/contrib/tensorrt/segment/segment.h" 49#include "tensorrt/include/NvInfer.h" 50 51//------------------------------------------------------------------------------ 52namespace tensorflow { 53namespace tensorrt { 54namespace convert { 55namespace { 56 57static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) { 58 // LINT.IfChange 59 // TODO(jie): Segmentation shouldn't associated with op name. 60 // Split it into a registration for each kernel. 61 static const std::set<std::string> candidate_ops = { 62 "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu", 63 "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean" 64 }; 65 // LINT.ThenChange( 66 // https://www.tensorflow.org/code/tensorflow/contrib/tensorrt/convert/convert_nodes.h) 67 return candidate_ops.count(node_def.op()); 68} 69 70void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, 71 const std::set<int>& subgraph_node_ids, 72 tensorflow::EdgeSet* incoming_edges) { 73 for (int node_id : subgraph_node_ids) { 74 const tensorflow::Node* node = graph.FindNodeId(node_id); 75 for (const tensorflow::Edge* edge : node->in_edges()) { 76 if (!subgraph_node_ids.count(edge->src()->id()) && 77 !edge->src()->IsSource()) { 78 incoming_edges->insert(edge); 79 } 80 } 81 } 82} 83 84void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, 85 const std::set<int>& subgraph_node_ids, 86 tensorflow::EdgeSet* outgoing_edges) { 87 for (int node_id : subgraph_node_ids) { 88 const tensorflow::Node* node = graph.FindNodeId(node_id); 89 for (const tensorflow::Edge* edge : node->out_edges()) { 90 if (!subgraph_node_ids.count(edge->dst()->id()) && 91 !edge->dst()->IsSink()) { 92 outgoing_edges->insert(edge); 93 } 94 } 95 } 96} 97 98std::pair<std::string, int> ParseTensorName(std::string name, 99 int default_idx = 0) { 100 int idx = default_idx; 101 size_t sep = name.find_last_of(':'); 102 if (sep != std::string::npos) { 103 name = name.substr(0, sep); 104 idx = std::stoi(name.substr(sep + 1)); 105 } 106 return std::make_pair(name, idx); 107} 108 109std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap( 110 const std::vector<std::string>& tensor_names) { 111 std::unordered_map<std::string, std::vector<int>> result; 112 for (std::string const& tensor_name : tensor_names) { 113 std::string node_name; 114 int index; 115 std::tie(node_name, index) = ParseTensorName(tensor_name); 116 result[node_name].push_back(index); 117 } 118 return result; 119} 120 121tensorflow::Status ConvertSubGraphToTensorRT( 122 const std::vector<std::string>& output_names, 123 const std::set<int>& subgraph_node_ids, 124 size_t max_batch_size, // max batch size that engine will be created for 125 // max amount of memory that engine will be allowed to consume, in bytes 126 size_t max_workspace_size, 127 const tensorflow::grappler::GraphProperties& graph_properties, 128 tensorflow::Graph* graph) { 129 tensorflow::EdgeSet subgraph_incoming_edges; 130 GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges); 131 132 std::vector<std::pair<int, int>> subgraph_inputs; 133 134 // Collect inputs by looking for incoming edges 135 for (const tensorflow::Edge* edge : subgraph_incoming_edges) { 136 subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); 137 } 138 std::set<std::pair<int, int>> subgraph_outputs_set; 139 // Collect outputs referenced from output_names 140 auto output_name_to_index_map = BuildTensorNameMap(output_names); 141 for (int node_id : subgraph_node_ids) { 142 tensorflow::Node* node = graph->FindNodeId(node_id); 143 if (output_name_to_index_map.count(node->name())) { 144 for (int index : output_name_to_index_map.at(node->name())) { 145 subgraph_outputs_set.insert({node_id, index}); 146 } 147 } 148 } 149 // Collect outputs referenced from outgoing edges 150 tensorflow::EdgeSet subgraph_outgoing_edges; 151 GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges); 152 for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { 153 subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); 154 } 155 // Impose an ordering on the outputs 156 std::vector<std::pair<int, int>> subgraph_outputs( 157 subgraph_outputs_set.begin(), subgraph_outputs_set.end()); 158 // Build TensorRT node and add it to the graph 159 tensorflow::NodeDef trt_node_def; 160 TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef( 161 *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs, 162 max_batch_size, max_workspace_size, graph_properties, &trt_node_def)); 163 tensorflow::Status status; 164 tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status); 165 166 TF_RETURN_IF_ERROR(status); 167 168 // Re-map outgoing edges to use the new TRT node instead of the orig subgraph 169 std::map<std::pair<int, int>, int> subgraph_edge_to_output_map; 170 for (size_t i = 0; i < subgraph_outputs.size(); ++i) { 171 subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i}); 172 } 173 TF_RETURN_IF_ERROR(status); 174 for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { 175 std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()}; 176 int new_src_output = subgraph_edge_to_output_map.at(old_src); 177 graph->UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input()); 178 } 179 // Remove the original subgraph 180 for (int node_id : subgraph_node_ids) { 181 tensorflow::Node* node = graph->FindNodeId(node_id); 182 // Don't remove the input placeholders 183 if (node->type_string() == "Placeholder") { 184 continue; 185 } 186 graph->RemoveNode(node); 187 } 188 return tensorflow::Status::OK(); 189} 190 191tensorflow::Status BuildNodeMap( 192 const tensorflow::Graph& graph, 193 std::unordered_map<std::string, tensorflow::Node*>* node_map) { 194 for (auto* node : graph.op_nodes()) { 195 if (!node_map->insert({node->name(), node}).second) { 196 return tensorflow::errors::AlreadyExists( 197 "Node name is not unique in graph: " + node->name()); 198 } 199 } 200 return tensorflow::Status::OK(); 201} 202 203} // namespace 204 205tensorflow::Status ConvertGraphDefToTensorRT( 206 const tensorflow::GraphDef& graph_def, 207 const std::vector<std::string>& output_names, size_t max_batch_size, 208 size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) { 209 // optimization pass 210 tensorflow::grappler::GrapplerItem item; 211 item.fetch = output_names; 212 tensorflow::GraphDef gdef; 213 214 // layout optimization 215 item.graph = graph_def; 216 tensorflow::grappler::LayoutOptimizer optimizer; 217 tensorflow::grappler::Cluster* cluster; 218 219 // virtual cluster 220 tensorflow::DeviceProperties device_properties; 221 device_properties.set_type("GPU"); 222 device_properties.mutable_environment()->insert({"architecture", "6"}); 223 cluster = 224 new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); 225 226 tensorflow::Status status = optimizer.Optimize(cluster, item, &gdef); 227 228 if (status != tensorflow::Status::OK()) return status; 229 230 // constant folding 231 item.graph = gdef; 232 tensorflow::grappler::ConstantFolding fold(nullptr); 233 status = fold.Optimize(nullptr, item, &gdef); 234 if (status != tensorflow::Status::OK()) return status; 235 236 // AJ refactoring shape inference through grappler/GraphProperties. 237 tensorflow::grappler::GraphProperties static_graph_properties(item); 238 static_graph_properties.InferStatically(false); 239 240 // Build full graph 241 tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), 242 gdef.library()); 243 tensorflow::Graph graph(flib); 244 TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( 245 tensorflow::GraphConstructorOptions(), gdef, &graph)); 246 247 // Segment the graph into subgraphs that can be converted to TensorRT 248 tensorflow::tensorrt::segment::SegmentOptions segment_options; 249 250 // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) 251 for (auto node : output_names) { 252 segment_options.exclude_node_list.insert(node); 253 } 254 255 // TODO(sami): this should be passed as a knob!!!! 256 segment_options.minimum_segment_size = 2; 257 tensorflow::tensorrt::segment::SegmentNodesVector segments; 258 TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( 259 gdef, IsTensorRTCandidate, segment_options, &segments)); 260 if (segments.size() > 1) { 261 VLOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); 262 } 263 std::unordered_map<std::string, tensorflow::Node*> node_map; 264 TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); 265 for (const std::set<std::string>& subgraph_node_names : segments) { 266 std::set<int> subgraph_node_ids; 267 for (const std::string& node_name : subgraph_node_names) { 268 subgraph_node_ids.insert(node_map.at(node_name)->id()); 269 } 270 TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT( 271 output_names, subgraph_node_ids, max_batch_size, max_workspace_size, 272 static_graph_properties, &graph)); 273 } 274 graph.ToGraphDef(new_graph_def); 275 return tensorflow::Status::OK(); 276} 277 278} // namespace convert 279} // namespace tensorrt 280} // namespace tensorflow 281 282#endif // GOOGLE_TENSORRT 283#endif // GOOGLE_CUDA 284