1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_GRAPPLER_UTILS_H_ 17#define TENSORFLOW_GRAPPLER_UTILS_H_ 18 19#include <functional> 20#include <unordered_map> 21#include <unordered_set> 22#include <vector> 23 24#include "tensorflow/core/framework/graph.pb.h" 25#include "tensorflow/core/framework/node_def.pb.h" 26#include "tensorflow/core/framework/types.h" 27#include "tensorflow/core/lib/core/status.h" 28#include "tensorflow/core/lib/core/threadpool.h" 29#include "tensorflow/core/lib/gtl/inlined_vector.h" 30 31namespace tensorflow { 32namespace grappler { 33 34// A utility class to lookup a node and its outputs by node name. 35class NodeMap { 36 public: 37 // Note: The NodeMap will store pointers to nodes in graph, which may become 38 // invalid if graph is changed. 39 explicit NodeMap(GraphDef* graph); 40 NodeDef* GetNode(const string& name) const; 41 bool NodeExists(const string& name) const; 42 const std::set<NodeDef*>& GetOutputs(const string& node_name) const; 43 // This method doesn't record the outputs of the added node; the outputs need 44 // to be explicitly added by the AddOutput method. 45 void AddNode(const string& name, NodeDef* node); 46 void RemoveNode(const string& name); 47 void UpdateInput(const string& node_name, const string& old_input_name, 48 const string& new_input_name); 49 void AddOutput(const string& node_name, const string& output_name); 50 void RemoveInputs(const string& node_name); 51 void RemoveOutput(const string& node_name, const string& output_name); 52 void RemoveOutputs(const string& node_name); 53 void UpdateOutput(const string& node_name, const string& old_output_name, 54 const string& new_output_name); 55 56 private: 57 const std::set<NodeDef*> empty_set_; 58 std::unordered_map<string, NodeDef*> nodes_; 59 std::unordered_map<string, std::set<NodeDef*>> outputs_; 60}; 61 62// A vector with a set. The set stores the same elements as the vector, and 63// quickly answers whether a value is in the vector. Duplicated elements are not 64// allowed for now. 65template <class T> 66class SetVector { 67 public: 68 // Returns false if value already existed in the set, true otherwise. 69 bool PushBack(const T& value) { 70 if (!set_.insert(value).second) { 71 return false; 72 } 73 vector_.push_back(value); 74 return true; 75 } 76 77 T PopBack() { 78 T back = vector_.back(); 79 set_.erase(back); 80 vector_.pop_back(); 81 return back; 82 } 83 84 bool Exists(const T& value) const { return set_.find(value) != set_.end(); } 85 86 bool Empty() const { return vector_.empty(); } 87 88 void Reserve(int64 size) { vector_.reserve(size); } 89 90 private: 91 std::unordered_set<T> set_; 92 std::vector<T> vector_; 93}; 94 95// True iff 'name' refers to a control inputs, i.e. a node name prefixed with 96// the ^ character. 97bool IsControlInput(const string& name); 98 99// True iff 'name1' and 'name2' refer to the same input. 100bool IsSameInput(const string& name1, const string& name2); 101 102// Return the node name corresponding to 'name' if name is valid, or the empty 103// string otherwise. 104string NodeName(const string& name); 105 106// Get the trailing position number ":{digits}" (if any) of a node name. 107int NodePosition(const string& name); 108 109// Returns the node name and position in a single call. 110string ParseNodeName(const string& name, int* position); 111 112// Add a prefix to a node name with a custom delimiter. 113string AddPrefixToNodeName(const string& name, const string& prefix, 114 const string& delimiter); 115 116// Add a prefix to a node name. 117string AddPrefixToNodeName(const string& name, const string& prefix); 118 119// Executes a 'fn' in the 'thread_pool'. The method waits for the configured 120// timeout (in milliseconds) for 'fn' to complete, before returning false. 121// 122// If returning false, the 'fn' may still continue to execute in the 123// thread-pool. It is the responsibility of the caller to reset the thread-pool 124// as appropriate. 125bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms, 126 thread::ThreadPool* thread_pool); 127 128// Returns the node name prefixed with conventional symbol '^' 129// for control dependency, given a NodeDef. 130string AsControlDependency(const NodeDef& node); 131 132// Returns the node name prefixed with conventional symbol '^' 133// for control dependency, given a node name 134string AsControlDependency(const string& node); 135 136// Returns the number of outputs of a node according to its OpDef. Note that 137// some of the outputs may be unconnected. 138int NumOutputs(const NodeDef& node, GraphDef* graph); 139 140// Number of connected non-control inputs. 141int NumNonControlInputs(const NodeDef& node); 142 143// Number of connected non-control outputs. 144int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); 145 146// Removes redundant control inputs from node. 147void DedupControlInputs(NodeDef* node); 148 149// Returns the data type in attribute `attr_name` of `node`. If that attribute 150// doesn't exist, returns DT_INVALID. 151DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name); 152 153// Returns the last node in the simple chain starting at source and traversing 154// through the input(0) edge from each node as long as the next node satisfies 155// the predicate given in pred_fn. If no nodes satisfy the predicate, &source 156// will be returned. Example: For the chain 157// source <- a <- b <- ... <- y <- z 158// where 159// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true, 160// pred_fn(z) = false, 161// the return value will be a pointer to y. 162NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, 163 bool follow_control_input, 164 const std::function<bool(const NodeDef&)>& pred_fn); 165 166// Permute the nodes of graph in place according to the permutation. 167void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, 168 bool invert_permutation); 169 170class SimpleGraphView { 171 public: 172 Status Initialize(const GraphDef& graph) { 173 return Initialize(graph, true, true); 174 } 175 Status Initialize(const GraphDef& graph, bool dedup_inputs, 176 bool dedup_outputs); 177 178 inline int num_nodes() const { return index_to_name_.size(); } 179 inline const int index(const string& node_name) const { 180 const auto& it = name_to_index_.find(node_name); 181 DCHECK(it != name_to_index_.end()); 182 return it == name_to_index_.end() ? -1 : it->second; 183 } 184 inline const string& node_name(int node_idx) const { 185 return index_to_name_[node_idx]; 186 } 187 inline const gtl::InlinedVector<int, 4>& inputs(int node_idx) const { 188 return inputs_[node_idx]; 189 } 190 inline const gtl::InlinedVector<int, 2>& outputs(int node_idx) const { 191 return outputs_[node_idx]; 192 } 193 194 string PrintToString() const; 195 196 private: 197 std::vector<string> index_to_name_; 198 std::unordered_map<string, int> name_to_index_; 199 std::vector<gtl::InlinedVector<int, 4>> inputs_; 200 std::vector<gtl::InlinedVector<int, 2>> outputs_; 201}; 202 203} // end namespace grappler 204} // end namespace tensorflow 205 206#endif // TENSORFLOW_GRAPPLER_UTILS_H_ 207