1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/graph/algorithm.h"
17
18#include <algorithm>
19#include <deque>
20#include <vector>
21
22#include "tensorflow/core/platform/logging.h"
23
24namespace tensorflow {
25
26void DFS(const Graph& g, const std::function<void(Node*)>& enter,
27         const std::function<void(Node*)>& leave,
28         const NodeComparator& stable_comparator) {
29  // Stack of work to do.
30  struct Work {
31    Node* node;
32    bool leave;  // Are we entering or leaving n?
33  };
34  std::vector<Work> stack;
35  stack.push_back(Work{g.source_node(), false});
36
37  std::vector<bool> visited(g.num_node_ids(), false);
38  while (!stack.empty()) {
39    Work w = stack.back();
40    stack.pop_back();
41
42    Node* n = w.node;
43    if (w.leave) {
44      leave(n);
45      continue;
46    }
47
48    if (visited[n->id()]) continue;
49    visited[n->id()] = true;
50    if (enter) enter(n);
51
52    // Arrange to call leave(n) when all done with descendants.
53    if (leave) stack.push_back(Work{n, true});
54
55    gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
56    auto add_work = [&visited, &stack](Node* out) {
57      if (!visited[out->id()]) {
58        // Note; we must not mark as visited until we actually process it.
59        stack.push_back(Work{out, false});
60      }
61    };
62
63    if (stable_comparator) {
64      std::vector<Node*> nodes_sorted;
65      for (Node* out : nodes) {
66        nodes_sorted.emplace_back(out);
67      }
68      std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
69      for (Node* out : nodes_sorted) {
70        add_work(out);
71      }
72    } else {
73      for (Node* out : nodes) {
74        add_work(out);
75      }
76    }
77  }
78}
79
80void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
81                const std::function<void(Node*)>& leave,
82                const NodeComparator& stable_comparator) {
83  ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator);
84}
85
86namespace {
87
88template <typename T>
89void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
90                          const std::function<void(T)>& enter,
91                          const std::function<void(T)>& leave,
92                          const NodeComparator& stable_comparator) {
93  // Stack of work to do.
94  struct Work {
95    T node;
96    bool leave;  // Are we entering or leaving n?
97  };
98  std::vector<Work> stack(start.size());
99  for (int i = 0; i < start.size(); ++i) {
100    stack[i] = Work{start[i], false};
101  }
102
103  std::vector<bool> visited(g.num_node_ids(), false);
104  while (!stack.empty()) {
105    Work w = stack.back();
106    stack.pop_back();
107
108    T n = w.node;
109    if (w.leave) {
110      leave(n);
111      continue;
112    }
113
114    if (visited[n->id()]) continue;
115    visited[n->id()] = true;
116    if (enter) enter(n);
117
118    // Arrange to call leave(n) when all done with descendants.
119    if (leave) stack.push_back(Work{n, true});
120
121    gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
122
123    auto add_work = [&visited, &stack](T out) {
124      if (!visited[out->id()]) {
125        // Note; we must not mark as visited until we actually process it.
126        stack.push_back(Work{out, false});
127      }
128    };
129
130    if (stable_comparator) {
131      std::vector<T> nodes_sorted;
132      for (T in : nodes) {
133        nodes_sorted.emplace_back(in);
134      }
135      std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
136      for (T in : nodes_sorted) {
137        add_work(in);
138      }
139    } else {
140      for (T in : nodes) {
141        add_work(in);
142      }
143    }
144  }
145}
146
147}  // namespace
148
149void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
150                    const std::function<void(const Node*)>& enter,
151                    const std::function<void(const Node*)>& leave,
152                    const NodeComparator& stable_comparator) {
153  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
154}
155
156void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
157                    const std::function<void(Node*)>& enter,
158                    const std::function<void(Node*)>& leave,
159                    const NodeComparator& stable_comparator) {
160  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
161}
162
163void GetPostOrder(const Graph& g, std::vector<Node*>* order,
164                  const NodeComparator& stable_comparator) {
165  order->clear();
166  DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
167}
168
169void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
170                         const NodeComparator& stable_comparator) {
171  GetPostOrder(g, order, stable_comparator);
172  std::reverse(order->begin(), order->end());
173}
174
175bool PruneForReverseReachability(Graph* g,
176                                 std::unordered_set<const Node*> visited) {
177  // Compute set of nodes that we need to traverse in order to reach
178  // the nodes in "nodes" by performing a breadth-first search from those
179  // nodes, and accumulating the visited nodes.
180  std::deque<const Node*> queue;
181  for (const Node* n : visited) {
182    VLOG(2) << "Reverse reach init: " << n->name();
183    queue.push_back(n);
184  }
185  while (!queue.empty()) {
186    const Node* n = queue.front();
187    queue.pop_front();
188    for (const Node* in : n->in_nodes()) {
189      if (visited.insert(in).second) {
190        queue.push_back(in);
191        VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
192      }
193    }
194  }
195
196  // Make a pass over the graph to remove nodes not in "visited"
197  std::vector<Node*> all_nodes;
198  all_nodes.reserve(g->num_nodes());
199  for (Node* n : g->nodes()) {
200    all_nodes.push_back(n);
201  }
202
203  bool any_removed = false;
204  for (Node* n : all_nodes) {
205    if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
206      g->RemoveNode(n);
207      any_removed = true;
208    }
209  }
210
211  return any_removed;
212}
213
214bool FixupSourceAndSinkEdges(Graph* g) {
215  // Connect all nodes with no incoming edges to source.
216  // Connect all nodes with no outgoing edges to sink.
217  bool changed = false;
218  for (Node* n : g->nodes()) {
219    if (!n->IsSource() && n->in_edges().empty()) {
220      g->AddControlEdge(g->source_node(), n);
221      changed = true;
222    }
223    if (!n->IsSink() && n->out_edges().empty()) {
224      g->AddControlEdge(n, g->sink_node());
225      changed = true;
226    }
227  }
228  return changed;
229}
230
231}  // namespace tensorflow
232