1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4vcyou may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ 17#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ 18 19#include <array> 20#include <unordered_map> 21#include <unordered_set> 22#include <vector> 23 24#include "tensorflow/core/common_runtime/shape_refiner.h" 25#include "tensorflow/core/framework/graph.pb.h" 26#include "tensorflow/core/framework/graph_transfer_info.pb.h" 27#include "tensorflow/core/framework/shape_inference.h" 28#include "tensorflow/core/graph/graph.h" 29#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h" 30#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 31#include "tensorflow/core/platform/macros.h" 32#include "tensorflow/core/platform/protobuf.h" 33#include "tensorflow/core/util/padding.h" 34 35namespace tensorflow { 36 37// GraphTransferer transfers graph definitions into SoC memory. 38// This functionality is effective if SoC is capable to run 39// the graph on that chip. 40// TODO(satok): support transferring subgraphs to be able to split graphs 41// to avoid unsupported ops in SoC. 42class GraphTransferer { 43 public: 44 // TODO(satok): Remove. Use proto definition instead. 45 static constexpr int MAX_SUPPORTED_RANK = 4; 46 // TODO(satok): Remove. Use proto definition instead. 47 static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK; 48 using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap; 49 50 GraphTransferer() = default; 51 52 // Load graph structure into GraphTransferer 53 // TODO(satok): Pass a pair of TensorShape and DataType instead of 54 // Tensor as input_node_info_list. 55 Status LoadGraphFromProto( 56 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 57 const GraphDef& graph_def, 58 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 59 const std::vector<string>& output_node_names, 60 const bool shape_inference_for_unknown_shape); 61 62 // Load graph structure into GraphTransferer from protobuf file 63 // TODO(satok): Pass a pair of TensorShape and DataType instead of 64 // Tensor as input_node_info_list. 65 Status LoadGraphFromProtoFile( 66 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 67 const string& graph_def_path, 68 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 69 const std::vector<string>& output_node_names, const bool is_text_proto, 70 const bool shape_inference_for_unknown_shape, 71 const bool dry_run_for_unknown_shape); 72 73 // Sort params so that all input nodes appear before consumer nodes. 74 // CAVEAT: This may be slow if the number of nodes are too large 75 void SortParams(const std::vector<string>& output_node_names); 76 77 void EnableStrictCheckMode(bool enable); 78 79 // Import parameters for transfer 80 void SetSerializedGraphTransferInfo(const string& serialized_proto); 81 82 // Return parameters for graph transfer 83 const GraphTransferInfo& GetGraphTransferInfo() const; 84 85 // Return mutable GraphTransferInfo for graph transfer 86 GraphTransferInfo& GetMutableGraphTransferInfo(); 87 88 // Dump verification string of parameters to verify with offline tools 89 void DumpVerificationStringOfNodeTransferParams() const; 90 91 static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray( 92 const TensorShape& shape); 93 94 private: 95 class TransferParamsComparator { 96 public: 97 TransferParamsComparator( 98 const std::unordered_map<int, std::unordered_set<int>>& dep_map); 99 bool operator()(const GraphTransferInfo::NodeInfo& obj0, 100 const GraphTransferInfo::NodeInfo& obj1); 101 const std::unordered_map<int, std::unordered_set<int>>& dependency_map_; 102 }; 103 104 void CacheNode(const Node& node); 105 106 bool AreAllInputsCached(const Node& node) const; 107 108 // Transform a remote fused graph to add an aggregated input node which takes 109 // all inputs of the remote graph. 110 Status TransformGraphToAddAggregatedInputNode( 111 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 112 Graph* graph, ShapeRefiner* shape_refiner); 113 114 Status RegisterNode( 115 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 116 const ShapeRefiner& shape_refiner, const Node& node, 117 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 118 const std::vector<string>& output_node_names); 119 120 void RegisterConstantNode(const ShapeRefiner& shape_refiner, 121 const Node& node); 122 123 int RegisterConstantShape(const std::vector<int>& shape); 124 125 int RegisterConstTensor(const Tensor& tensor, const string& suffix); 126 127 int RegisterConstScalar(const DataType dt, const int val, const int dst_id, 128 const int dst_input_count); 129 130 bool HasPaddingAndStrides(const Node& node); 131 132 bool NeedsToAddRank(const Node& node); 133 134 bool IsPadNode(const Node& node); 135 136 // Return true if the node is a reshape op which just flattens input 137 // TODO(satok): Remove this method once generic reshape op is implemented in 138 // SOC 139 bool IsNodeFlattenReshape(const Node& node, 140 const ShapeRefiner& shape_refiner); 141 142 void RegisterNodeWithPaddingAndStrides( 143 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 144 const ShapeRefiner& shape_refiner, const Node& node); 145 146 void RegisterNodeWithRank( 147 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 148 const ShapeRefiner& shape_refiner, const Node& node); 149 150 void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions, 151 const ShapeRefiner& shape_refiner, const Node& node); 152 153 void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions, 154 const ShapeRefiner& shape_refiner, const Node& node); 155 156 void RegisterFlattenNode( 157 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 158 const ShapeRefiner& shape_refiner, const Node& node); 159 160 void RegisterGenericNode( 161 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 162 const ShapeRefiner& shape_refiner, const Node& node); 163 164 Status RegisterNodeIfAllInputsAreCached( 165 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 166 const ShapeRefiner& shape_refiner, const Node& node, 167 const bool only_register_const_node, 168 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 169 const std::vector<string>& output_node_names); 170 171 void AppendNodeParams(const string& name, const int id, const string& type, 172 const int type_id, const int padding, 173 const int inputs_size, 174 const std::vector<int>& extra_inputs, 175 const int outputs_size); 176 177 void AddNodeInputByInputIndex( 178 const Node& node, const int idx, 179 GraphTransferInfo::NodeInputInfo* node_input_info); 180 181 void AppendNodeInputParams(const int id, const Node& node, 182 const std::vector<int>& extra_inputs); 183 184 void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id, 185 const Node& node); 186 187 static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray( 188 const shape_inference::ShapeHandle& shape_handle, 189 shape_inference::InferenceContext* context); 190 191 void AppendNodeParamsWithIoParams( 192 const ShapeRefiner& shape_refiner, const Node& node, const string& name, 193 const int id, const string& type, const int type_id, const int padding, 194 const int inputs_size, const std::vector<int>& extra_inputs, 195 const int outputs_size, const bool append_input_params, 196 const bool append_output_params); 197 198 static string ToPaddingDebugString(int padding); 199 200 // Create dependency map 201 static void FillDependencyRec( 202 int node_id, std::unordered_map<int, std::unordered_set<int>>& dep_map, 203 std::unordered_set<int>& completed); 204 205 // Build tensor from proto 206 static Status MakeTensorFromProto(const TensorProto& tensor_proto, 207 Tensor* tensor); 208 209 void ClearCache(); 210 211 // Dump pretty print of parameters 212 void DumpNodeTransferParams() const; 213 214 GraphTransferInfo graph_transfer_info_{}; 215 216 std::vector<const Node*> node_name_cache_list_{}; 217 std::unordered_map<string, int> node_name_to_id_cache_map_{}; 218 219 // strict check mode is true by default. Disable this if the ops' shape 220 // inferences are not implemented correctly. 221 bool strict_check_mode_{true}; 222 223 TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer); 224}; 225 226} // namespace tensorflow 227 228#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H 229