/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_

#include <array>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// GraphTransferer transfers graph definitions into SoC memory.
// This functionality is effective only when the SoC is capable of running
// the graph on that chip.
// TODO(satok): support transferring subgraphs so that graphs can be split
// to avoid ops that are unsupported on the SoC.
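//
// A minimal usage sketch (illustrative only: the ops-definitions instance,
// graph path, node names, and tensor shape below are assumptions, not part
// of this header):
//
//   GraphTransferer gt;
//   const IRemoteFusedGraphOpsDefinitions& ops_definitions =
//       ...;  // e.g. a Hexagon-specific implementation
//   std::vector<std::pair<string, Tensor>> inputs = {
//       {"input", Tensor(DT_FLOAT, TensorShape({1, 224, 224, 3}))}};
//   std::vector<string> outputs = {"softmax"};
//   TF_CHECK_OK(gt.LoadGraphFromProtoFile(
//       ops_definitions, "/tmp/graph.pb", inputs, outputs,
//       /*is_text_proto=*/false,
//       /*shape_inference_for_unknown_shape=*/true,
//       /*dry_run_for_unknown_shape=*/false));
//   const GraphTransferInfo& info = gt.GetGraphTransferInfo();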
class GraphTransferer {
 public:
  // TODO(satok): Remove. Use proto definition instead.
  static constexpr int MAX_SUPPORTED_RANK = 4;
  // TODO(satok): Remove. Use proto definition instead.
  static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
  using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;

  GraphTransferer() = default;

  // Load graph structure into GraphTransferer
  // TODO(satok): Pass a pair of TensorShape and DataType instead of
  // Tensor as input_node_info_list.
  Status LoadGraphFromProto(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const GraphDef& graph_def,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names,
      const bool shape_inference_for_unknown_shape);
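
  // For example (a sketch; "ops_definitions", "graph_def", and the node names
  // below are placeholders), loading an in-memory GraphDef could look like:
  //
  //   TF_RETURN_IF_ERROR(transferer.LoadGraphFromProto(
  //       ops_definitions, graph_def,
  //       {{"input", Tensor(DT_FLOAT, TensorShape({1, 8, 8, 3}))}},
  //       {"output"},
  //       /*shape_inference_for_unknown_shape=*/true));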

  // Load graph structure into GraphTransferer from protobuf file
  // TODO(satok): Pass a pair of TensorShape and DataType instead of
  // Tensor as input_node_info_list.
  Status LoadGraphFromProtoFile(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const string& graph_def_path,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names, const bool is_text_proto,
      const bool shape_inference_for_unknown_shape,
      const bool dry_run_for_unknown_shape);

  // Sort params so that all input nodes appear before consumer nodes.
  // CAVEAT: This may be slow if the number of nodes is too large.
  void SortParams(const std::vector<string>& output_node_names);

  void EnableStrictCheckMode(bool enable);

  // Import parameters for transfer
  void SetSerializedGraphTransferInfo(const string& serialized_proto);

  // Return parameters for graph transfer
  const GraphTransferInfo& GetGraphTransferInfo() const;
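
  // For example, parameters can be carried to another GraphTransferer by
  // round-tripping the proto (a sketch; how the serialized string is
  // transported is up to the caller):
  //
  //   const string serialized =
  //       transferer.GetGraphTransferInfo().SerializeAsString();
  //   other_transferer.SetSerializedGraphTransferInfo(serialized);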

  // Return mutable GraphTransferInfo for graph transfer
  GraphTransferInfo& GetMutableGraphTransferInfo();

  // Dump verification string of parameters to verify with offline tools
  void DumpVerificationStringOfNodeTransferParams() const;

  static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray(
      const TensorShape& shape);
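
  // For example (an assumption about the layout: shapes with rank < 4 are
  // padded with leading 1s to fill the rank-4 array):
  //
  //   ToTensorShapeArray(TensorShape({2, 3}));  // -> {1, 1, 2, 3}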

 private:
  class TransferParamsComparator {
   public:
    TransferParamsComparator(
        const std::unordered_map<int, std::unordered_set<int>>& dep_map);
    bool operator()(const GraphTransferInfo::NodeInfo& obj0,
                    const GraphTransferInfo::NodeInfo& obj1);
    const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
  };

  void CacheNode(const Node& node);

  bool AreAllInputsCached(const Node& node) const;

  // Transform a remote fused graph to add an aggregated input node which takes
  // all inputs of the remote graph.
  Status TransformGraphToAddAggregatedInputNode(
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      Graph* graph, ShapeRefiner* shape_refiner);

  Status RegisterNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names);

  void RegisterConstantNode(const ShapeRefiner& shape_refiner,
                            const Node& node);

  int RegisterConstantShape(const std::vector<int>& shape);

  int RegisterConstTensor(const Tensor& tensor, const string& suffix);

  int RegisterConstScalar(const DataType dt, const int val, const int dst_id,
                          const int dst_input_count);

  bool HasPaddingAndStrides(const Node& node);

  bool NeedsToAddRank(const Node& node);

  bool IsPadNode(const Node& node);

  // Return true if the node is a reshape op that just flattens the input.
  // TODO(satok): Remove this method once a generic reshape op is implemented
  // on the SoC.
  bool IsNodeFlattenReshape(const Node& node,
                            const ShapeRefiner& shape_refiner);

  void RegisterNodeWithPaddingAndStrides(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterNodeWithRank(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
                       const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
                         const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterFlattenNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterGenericNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  Status RegisterNodeIfAllInputsAreCached(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node,
      const bool only_register_const_node,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names);

  void AppendNodeParams(const string& name, const int id, const string& type,
                        const int type_id, const int padding,
                        const int inputs_size,
                        const std::vector<int>& extra_inputs,
                        const int outputs_size);

  void AddNodeInputByInputIndex(
      const Node& node, const int idx,
      GraphTransferInfo::NodeInputInfo* node_input_info);

  void AppendNodeInputParams(const int id, const Node& node,
                             const std::vector<int>& extra_inputs);

  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
                              const Node& node);

  static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
      const shape_inference::ShapeHandle& shape_handle,
      shape_inference::InferenceContext* context);

  void AppendNodeParamsWithIoParams(
      const ShapeRefiner& shape_refiner, const Node& node, const string& name,
      const int id, const string& type, const int type_id, const int padding,
      const int inputs_size, const std::vector<int>& extra_inputs,
      const int outputs_size, const bool append_input_params,
      const bool append_output_params);

  static string ToPaddingDebugString(int padding);

  // Create dependency map
  static void FillDependencyRec(
      int node_id, std::unordered_map<int, std::unordered_set<int>>& dep_map,
      std::unordered_set<int>& completed);

  // Build tensor from proto
  static Status MakeTensorFromProto(const TensorProto& tensor_proto,
                                    Tensor* tensor);

  void ClearCache();

  // Dump pretty print of parameters
  void DumpNodeTransferParams() const;

  GraphTransferInfo graph_transfer_info_{};

  std::vector<const Node*> node_name_cache_list_{};
  std::unordered_map<string, int> node_name_to_id_cache_map_{};

  // Strict check mode is true by default. Disable this if the ops' shape
  // inferences are not implemented correctly.
  bool strict_check_mode_{true};

  TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_