convert_graph.cc revision 6908cc233c679b8fe61d99a30d3828362caf47be
1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" 17 18#include <list> 19#include <set> 20#include <sstream> 21#include <string> 22#include <unordered_map> 23#include <unordered_set> 24#include <vector> 25#include <map> 26#include <utility> 27 28#include "NvInfer.h" 29 30#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" 31#include "tensorflow/contrib/tensorrt/segment/segment.h" 32#include "tensorflow/core/framework/graph.pb.h" 33#include "tensorflow/core/framework/node_def.pb.h" 34#include "tensorflow/core/graph/algorithm.h" 35#include "tensorflow/core/graph/graph.h" 36#include "tensorflow/core/graph/graph_constructor.h" 37#include "tensorflow/core/lib/core/errors.h" 38#include "tensorflow/core/lib/core/status.h" 39#include "tensorflow/core/platform/logging.h" 40 41#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) 42#include "tensorflow/core/grappler/optimizers/constant_folding.h" 43#include "tensorflow/core/grappler/optimizers/layout_optimizer.h" 44#include "tensorflow/core/grappler/devices.h" 45#include "tensorflow/core/grappler/clusters/virtual_cluster.h" 46#include "tensorflow/core/protobuf/device_properties.pb.h" 47#include "tensorflow/core/grappler/grappler_item.h" 48#include "tensorflow/core/grappler/utils.h" 49 50#include "tensorflow/core/grappler/costs/graph_properties.h" 51 52//------------------------------------------------------------------------------ 53namespace tensorflow { 54namespace tensorrt { 55namespace convert { 56namespace { 57 58static std::unordered_set<std::string> output_nodes; 59bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) { 60// LINT.IfChange 61 static const std::set<std::string> candidate_ops = { 62 "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu", 63 "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean" 64 // TODO(ben,jie): ... 65 }; 66// LINT.ThenChange( 67// https://www.tensorflow.org/code/tensorflow/contrib/tensorrt/convert/convert_nodes.h) 68 69 if (output_nodes.count(node_def.name())) return false; 70 return candidate_ops.count(node_def.op()); 71} 72 73void GetSubGraphIncomingEdges(const tensorflow::Graph & graph, 74 const std::set<int> &subgraph_node_ids, 75 tensorflow::EdgeSet* incoming_edges) { 76 for (int node_id : subgraph_node_ids) { 77 tensorflow::Node const* node = graph.FindNodeId(node_id); 78 LOG(DEBUG) << node->name() << " has incoming edges: "; 79 for (tensorflow::Edge const* edge : node->in_edges()) { 80 if (!subgraph_node_ids.count(edge->src()->id()) && 81 !edge->src()->IsSource()) { 82 LOG(DEBUG) << edge->src()->name() << ", "; 83 incoming_edges->insert(edge); 84 } 85 } 86 } 87} 88 89void GetSubGraphOutgoingEdges(const tensorflow::Graph &graph, 90 const std::set<int> &subgraph_node_ids, 91 tensorflow::EdgeSet* outgoing_edges) { 92 for (int node_id : subgraph_node_ids) { 93 tensorflow::Node const* node = graph.FindNodeId(node_id); 94 LOG(DEBUG) << node->name() << " has outgoing edges: "; 95 for (tensorflow::Edge const* edge : node->out_edges()) { 96 if (!subgraph_node_ids.count(edge->dst()->id()) && 97 !edge->dst()->IsSink()) { 98 outgoing_edges->insert(edge); 99 } 100 } 101 } 102} 103 104std::pair<std::string, int> ParseTensorName(std::string name, 105 int default_idx = 0) { 106 int idx = default_idx; 107 size_t sep = name.find_last_of(':'); 108 if (sep != std::string::npos) { 109 name = name.substr(0, sep); 110 idx = std::stoi(name.substr(sep + 1)); 111 } 112 return std::make_pair(name, idx); 113} 114 115std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap( 116 const std::vector<std::string>& tensor_names) { 117 std::unordered_map<std::string, std::vector<int>> result; 118 for (std::string const& tensor_name : tensor_names) { 119 std::string node_name; 120 int index; 121 std::tie(node_name, index) = ParseTensorName(tensor_name); 122 result[node_name].push_back(index); 123 } 124 return result; 125} 126 127tensorflow::Status ConvertSubGraphToTensorRT( 128 tensorflow::Graph& graph, const std::vector<std::string>& output_names, 129 const std::set<int>& subgraph_node_ids, 130 size_t max_batch_size, // max batch size that engine will be created for 131 // max amount of memory that engine will be allowed to consume, in bytes 132 size_t max_workspace_size, 133 const tensorflow::grappler::GraphProperties& graph_properties) { 134 tensorflow::EdgeSet subgraph_incoming_edges; 135 GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges); 136 137 std::vector<std::pair<int, int>> subgraph_inputs; 138 139 140 // Collect inputs by looking for incoming edges 141 for (tensorflow::Edge const* edge : subgraph_incoming_edges) { 142 subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); 143 } 144 std::set<std::pair<int, int>> subgraph_outputs_set; 145 // Collect outputs referenced from output_names 146 auto output_name_to_index_map = BuildTensorNameMap(output_names); 147 for (int node_id : subgraph_node_ids) { 148 tensorflow::Node* node = graph.FindNodeId(node_id); 149 if (output_name_to_index_map.count(node->name())) { 150 for (int index : output_name_to_index_map.at(node->name())) { 151 subgraph_outputs_set.insert({node_id, index}); 152 } 153 } 154 } 155 // Collect outputs referenced from outgoing edges 156 tensorflow::EdgeSet subgraph_outgoing_edges; 157 GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges); 158 for (tensorflow::Edge const* edge : subgraph_outgoing_edges) { 159 subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); 160 } 161 // Impose an ordering on the outputs 162 std::vector<std::pair<int, int>> subgraph_outputs( 163 subgraph_outputs_set.begin(), subgraph_outputs_set.end()); 164 // Build TensorRT node and add it to the graph 165 tensorflow::NodeDef trt_node_def; 166 TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef( 167 graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs, 168 max_batch_size, max_workspace_size, graph_properties, &trt_node_def)); 169 tensorflow::Status status; 170 tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status); 171 172 TF_RETURN_IF_ERROR(status); 173 174 // Re-map outgoing edges to use the new TRT node instead of the orig subgraph 175 std::map<std::pair<int, int>, int> subgraph_edge_to_output_map; 176 for (size_t i = 0; i < subgraph_outputs.size(); ++i) { 177 subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i}); 178 } 179 TF_RETURN_IF_ERROR(status); 180 for (tensorflow::Edge const* edge : subgraph_outgoing_edges) { 181 std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()}; 182 int new_src_output = subgraph_edge_to_output_map.at(old_src); 183 graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input()); 184 } 185 // Remove the original subgraph 186 for (int node_id : subgraph_node_ids) { 187 tensorflow::Node* node = graph.FindNodeId(node_id); 188 // Don't remove the input placeholders 189 if (node->type_string() == "Placeholder") { 190 continue; 191 } 192 graph.RemoveNode(node); 193 } 194 return tensorflow::Status::OK(); 195} 196 197tensorflow::Status BuildNodeMap( 198 const tensorflow::Graph& graph, 199 std::unordered_map<std::string, tensorflow::Node*>* node_map) { 200 for (auto* node : graph.op_nodes()) { 201 if (!node_map->insert({node->name(), node}).second) { 202 return tensorflow::errors::AlreadyExists( 203 "Node name is not unique in graph: " + node->name()); 204 } 205 } 206 return tensorflow::Status::OK(); 207} 208 209} // namespace 210 211tensorflow::Status ConvertGraphDefToTensorRT( 212 const tensorflow::GraphDef& graph_def, 213 const std::vector<std::string>& output_names, size_t max_batch_size, 214 size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) { 215 216 // optimization pass 217 tensorflow::grappler::GrapplerItem item; 218 item.fetch = output_names; 219 tensorflow::GraphDef gdef; 220 221 // layout optimization 222 item.graph = graph_def; 223 tensorflow::grappler::LayoutOptimizer optimizer; 224 tensorflow::grappler::Cluster* gCluster; 225 226 // virtual cluster 227 tensorflow::DeviceProperties device_properties; 228 device_properties.set_type("GPU"); 229 device_properties.mutable_environment()->insert({"architecture", "6"}); 230 gCluster = 231 new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); 232 233 // single machine 234 int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); 235 int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); 236 LOG(DEBUG) << "cpu_cores: " << num_cpu_cores; 237 LOG(DEBUG) << "gpus: " << num_gpus; 238 239 tensorflow::Status status = optimizer.Optimize(gCluster, item, &gdef); 240 241 if (status !=tensorflow::Status::OK()) 242 return status; 243 244 // constant folding 245 item.graph = gdef; 246 tensorflow::grappler::ConstantFolding fold(nullptr); 247 status = fold.Optimize(nullptr, item, &gdef); 248 if (status !=tensorflow::Status::OK()) 249 return status; 250 251 // AJ refactoring shape inference through grappler/GraphProperties. 252 tensorflow::grappler::GraphProperties static_graph_properties(item); 253 static_graph_properties.InferStatically(false); 254 255 // Build full graph 256 tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), 257 gdef.library()); 258 tensorflow::Graph graph(flib); 259 TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( 260 tensorflow::GraphConstructorOptions(), gdef, &graph)); 261 262 // Segment the graph into subgraphs that can be converted to TensorRT 263 tensorflow::tensorrt::segment::SegmentOptions segment_options; 264 // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) 265 for (auto node : output_names) output_nodes.insert(node); 266 267 // TODO(sami): this should be passed as a knob!!!! 268 segment_options.minimum_segment_size = 2; 269 tensorflow::tensorrt::segment::SegmentNodesVector segments; 270 TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( 271 gdef, IsTensorRTCandidate, segment_options, &segments)); 272 if (segments.size() > 1) { 273 LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); 274 } 275 std::unordered_map<std::string, tensorflow::Node*> node_map; 276 TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); 277 for (const std::set<std::string> &subgraph_node_names : segments) { 278 std::set<int> subgraph_node_ids; 279 for (const std::string &node_name : subgraph_node_names) { 280 subgraph_node_ids.insert(node_map.at(node_name)->id()); 281 } 282 TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT( 283 graph, output_names, subgraph_node_ids, max_batch_size, 284 max_workspace_size, static_graph_properties)); 285 } 286 graph.ToGraphDef(new_graph_def); 287 return tensorflow::Status::OK(); 288} 289 290} // namespace convert 291} // namespace tensorrt 292} // namespace tensorflow 293