1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/graph/algorithm.h"
17
18#include <string>
19#include <vector>
20
21#include "tensorflow/core/graph/graph.h"
22#include "tensorflow/core/graph/graph_def_builder.h"
23#include "tensorflow/core/graph/graph_def_builder_util.h"
24#include "tensorflow/core/graph/subgraph.h"
25#include "tensorflow/core/kernels/ops_util.h"
26#include "tensorflow/core/lib/core/status.h"
27#include "tensorflow/core/lib/core/status_test_util.h"
28#include "tensorflow/core/platform/test.h"
29
30// TODO(josh11b): Test setting the "device" field of a NodeDef.
31// TODO(josh11b): Test that feeding won't prune targets.
32
33namespace tensorflow {
34namespace {
35
// Minimal op definitions referenced by name when the tests in this file build
// their graphs.
//
// "TestParams": no inputs, a single float output named "o".
REGISTER_OP("TestParams").Output("o: float");
// "TestInput": no inputs, two float outputs named "a" and "b".
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
// "TestMul": two float inputs named "a" and "b", one float output named "o".
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
39
40// Compares that the order of nodes in 'inputs' respects the
41// pair orders described in 'ordered_pairs'.
42bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
43                  const std::vector<Node*>& inputs, string* error) {
44  for (const std::pair<string, string>& pair : ordered_pairs) {
45    const string& before_node = pair.first;
46    const string& after_node = pair.second;
47    bool seen_before = false;
48    bool seen_both = false;
49    for (const Node* node : inputs) {
50      if (!seen_before && after_node == node->name()) {
51        *error = strings::StrCat("Saw ", after_node, " before ", before_node);
52        return false;
53      }
54
55      if (before_node == node->name()) {
56        seen_before = true;
57      } else if (after_node == node->name()) {
58        seen_both = seen_before;
59        break;
60      }
61    }
62    if (!seen_both) {
63      *error = strings::StrCat("didn't see either ", before_node, " or ",
64                               after_node);
65      return false;
66    }
67  }
68
69  return true;
70}
71
72TEST(AlgorithmTest, ReversePostOrder) {
73  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
74  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
75  Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
76  Node* w2 = SourceOp("TestParams", b.opts().WithName("W2"));
77  Node* input =
78      SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
79  Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1"));
80  BinaryOp("TestMul", w1, {input, 1},
81           b.opts().WithName("t2").WithControlInput(t1));
82  BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));
83
84  Graph g(OpRegistry::Global());
85  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
86  std::vector<Node*> order;
87
88  // Test reverse post order:
89  GetReversePostOrder(g, &order);
90
91  // Check that the order respects the dependencies correctly.
92  std::vector<std::pair<string, string>> reverse_orders = {
93      {"W1", "input"}, {"W1", "t1"},    {"W1", "t2"}, {"W1", "t3"},
94      {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
95  string error;
96  EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error;
97
98  // A false ordering should fail the check.
99  reverse_orders = {{"input", "W1"}};
100  EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error));
101
102  // Test post order:
103  GetPostOrder(g, &order);
104
105  // Check that the order respects the dependencies correctly.
106  std::vector<std::pair<string, string>> orders = {
107      {"input", "W1"}, {"t1", "W1"},    {"t2", "W1"}, {"t3", "W1"},
108      {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
109  EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error;
110
111  // A false ordering should fail the check.
112  orders = {{"W1", "t3"}};
113  EXPECT_FALSE(ExpectBefore(orders, order, &error));
114}
115
116TEST(AlgorithmTest, ReversePostOrderStable) {
117  int64 run_count = 100;
118  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
119
120  for (int64 i = 0; i < run_count; ++i) {
121    // One source of nondeterminism comes from unordered set with key of a
122    // pointer type, for example the order of FlatSet<Node*> depends on the
123    // raw pointer value of Node. Stable post order suppose to remove this
124    // nondeterminism by enforcing an ordering based on node ids.
125    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
126    string error;
127    Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
128    Node* input =
129        SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
130    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2"));
131    // Insert different number of nodes between the allocation of t2 and t3,
132    // this creates enough entropy in the memory distance between t2 and t3 thus
133    // forces them to have randomized ordering had stable DFS was not
134    // implemented correctly.
135    for (int64 j = 0; j < i; ++j) {
136      BinaryOp("TestMul", w1, {input, 1},
137               b.opts().WithName(strings::StrCat("internal", j)));
138    }
139
140    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));
141
142    Graph g(OpRegistry::Global());
143    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
144    std::vector<Node*> order;
145
146    // Test reverse post order generates expected ordering.
147    GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorID());
148    EXPECT_TRUE(ExpectBefore({{"t3", "t2"}}, order, &error));
149  }
150}
151}  // namespace
152}  // namespace tensorflow
153