1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_GRAPPLER_UTILS_H_
17#define TENSORFLOW_GRAPPLER_UTILS_H_
18
#include <cstddef>
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
30
31namespace tensorflow {
32namespace grappler {
33
34// A utility class to lookup a node and its outputs by node name.
35class NodeMap {
36 public:
37  // Note: The NodeMap will store pointers to nodes in graph, which may become
38  // invalid if graph is changed.
39  explicit NodeMap(GraphDef* graph);
40  NodeDef* GetNode(const string& name) const;
41  bool NodeExists(const string& name) const;
42  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
43  // This method doesn't record the outputs of the added node; the outputs need
44  // to be explicitly added by the AddOutput method.
45  void AddNode(const string& name, NodeDef* node);
46  void RemoveNode(const string& name);
47  void UpdateInput(const string& node_name, const string& old_input_name,
48                   const string& new_input_name);
49  void AddOutput(const string& node_name, const string& output_name);
50  void RemoveInputs(const string& node_name);
51  void RemoveOutput(const string& node_name, const string& output_name);
52  void RemoveOutputs(const string& node_name);
53  void UpdateOutput(const string& node_name, const string& old_output_name,
54                    const string& new_output_name);
55
56 private:
57  const std::set<NodeDef*> empty_set_;
58  std::unordered_map<string, NodeDef*> nodes_;
59  std::unordered_map<string, std::set<NodeDef*>> outputs_;
60};
61
// A vector paired with a hash set that stores the same elements. The vector
// keeps insertion order and acts as a stack (PushBack / PopBack); the set
// answers "is this value present?" without scanning the vector. Duplicated
// elements are not allowed for now: PushBack rejects a value that is already
// present.
//
// T must be copyable, hashable with std::hash<T>, and equality-comparable.
template <class T>
class SetVector {
 public:
  // Appends `value` unless it is already present.
  // Returns false if value already existed in the set, true otherwise.
  bool PushBack(const T& value) {
    if (!set_.insert(value).second) {
      return false;
    }
    vector_.push_back(value);
    return true;
  }

  // Removes and returns the most recently pushed element.
  // Precondition: !Empty(). Calling this on an empty SetVector is undefined
  // behavior (std::vector::back on an empty vector).
  T PopBack() {
    // Move the element out of the vector rather than copying it; the set owns
    // its own copy, which is erased by value below.
    T back = std::move(vector_.back());
    vector_.pop_back();
    set_.erase(back);
    return back;
  }

  // Returns true iff `value` is currently stored.
  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }

  // Returns true iff no elements are stored.
  bool Empty() const { return vector_.empty(); }

  // Pre-allocates room for `size` elements. Every element lives in both
  // containers, so both are reserved; reserving only the vector would leave
  // the hash set rehashing repeatedly as it grows.
  void Reserve(int64 size) {
    vector_.reserve(size);
    set_.reserve(size);
  }

 private:
  std::unordered_set<T> set_;
  std::vector<T> vector_;
};
94
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);

// True iff 'name1' and 'name2' refer to the same input.
// NOTE(review): the exact equivalence rule (e.g. whether "foo" and "foo:0" are
// considered the same) lives in the implementation — confirm in utils.cc.
bool IsSameInput(const string& name1, const string& name2);

// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise. 'name' is an input string; presumably any "^" control
// prefix and ":{digits}" position suffix are stripped — confirm in utils.cc.
string NodeName(const string& name);

// Get the trailing position number ":{digits}" (if any) of a node name.
// NOTE(review): the value returned when there is no suffix, and for control
// inputs, is not visible in this header — confirm in utils.cc.
int NodePosition(const string& name);

// Returns the node name and position in a single call: the node name is the
// return value and the position is written to '*position'.
string ParseNodeName(const string& name, int* position);

// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter);

// Add a prefix to a node name, using the implementation's default delimiter
// (its value is defined in utils.cc, not in this header).
string AddPrefixToNodeName(const string& name, const string& prefix);
118
// Executes 'fn' in 'thread_pool'. The method waits up to the configured
// timeout (in milliseconds) for 'fn' to complete, and returns false if it does
// not. NOTE(review): presumably returns true when 'fn' finishes within the
// timeout — confirm in utils.cc.
//
// If returning false, the 'fn' may still continue to execute in the
// thread-pool. It is the responsibility of the caller to reset the thread-pool
// as appropriate.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
                        thread::ThreadPool* thread_pool);

// Returns the name of 'node' prefixed with '^', the conventional symbol that
// marks a control dependency.
string AsControlDependency(const NodeDef& node);

// Returns 'node' (a node name) prefixed with '^', the conventional symbol that
// marks a control dependency.
string AsControlDependency(const string& node);
135
// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
// NOTE(review): the role of 'graph' (e.g. resolving ops from the graph's
// function library) is not visible in this header — confirm in utils.cc.
int NumOutputs(const NodeDef& node, GraphDef* graph);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

// Number of connected non-control outputs. Presumably computed from the
// output sets recorded in 'node_map' — confirm in utils.cc.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Removes redundant control inputs from 'node', modifying it in place.
// NOTE(review): which control inputs count as "redundant" (duplicates only, or
// also those already implied by a data input) is defined in utils.cc.
void DedupControlInputs(NodeDef* node);

// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);

// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
//    source <- a <- b <- ... <- y <- z
// where
//    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
//    pred_fn(z) = false,
// the return value will be a pointer to y.
// NOTE(review): 'follow_control_input' presumably allows input(0) to be a
// control input ("^name") and still be followed — confirm in utils.cc.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn);

// Permute the nodes of graph in place according to the permutation.
// NOTE(review): the index convention of 'permutation' (new position of node i
// vs. source of position i), how 'invert_permutation' flips it, and whether
// '*permutation' is modified (it is passed by non-const pointer) are defined
// in utils.cc — confirm before relying on them.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation);
169
170class SimpleGraphView {
171 public:
172  Status Initialize(const GraphDef& graph) {
173    return Initialize(graph, true, true);
174  }
175  Status Initialize(const GraphDef& graph, bool dedup_inputs,
176                    bool dedup_outputs);
177
178  inline int num_nodes() const { return index_to_name_.size(); }
179  inline const int index(const string& node_name) const {
180    const auto& it = name_to_index_.find(node_name);
181    DCHECK(it != name_to_index_.end());
182    return it == name_to_index_.end() ? -1 : it->second;
183  }
184  inline const string& node_name(int node_idx) const {
185    return index_to_name_[node_idx];
186  }
187  inline const gtl::InlinedVector<int, 4>& inputs(int node_idx) const {
188    return inputs_[node_idx];
189  }
190  inline const gtl::InlinedVector<int, 2>& outputs(int node_idx) const {
191    return outputs_[node_idx];
192  }
193
194  string PrintToString() const;
195
196 private:
197  std::vector<string> index_to_name_;
198  std::unordered_map<string, int> name_to_index_;
199  std::vector<gtl::InlinedVector<int, 4>> inputs_;
200  std::vector<gtl::InlinedVector<int, 2>> outputs_;
201};
202
203}  // end namespace grappler
204}  // end namespace tensorflow
205
206#endif  // TENSORFLOW_GRAPPLER_UTILS_H_
207