1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/graph/algorithm.h" 17 18#include <string> 19#include <vector> 20 21#include "tensorflow/core/graph/graph.h" 22#include "tensorflow/core/graph/graph_def_builder.h" 23#include "tensorflow/core/graph/graph_def_builder_util.h" 24#include "tensorflow/core/graph/subgraph.h" 25#include "tensorflow/core/kernels/ops_util.h" 26#include "tensorflow/core/lib/core/status.h" 27#include "tensorflow/core/lib/core/status_test_util.h" 28#include "tensorflow/core/platform/test.h" 29 30// TODO(josh11b): Test setting the "device" field of a NodeDef. 31// TODO(josh11b): Test that feeding won't prune targets. 32 33namespace tensorflow { 34namespace { 35 36REGISTER_OP("TestParams").Output("o: float"); 37REGISTER_OP("TestInput").Output("a: float").Output("b: float"); 38REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); 39 40// Compares that the order of nodes in 'inputs' respects the 41// pair orders described in 'ordered_pairs'. 42bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs, 43 const std::vector<Node*>& inputs, string* error) { 44 for (const std::pair<string, string>& pair : ordered_pairs) { 45 const string& before_node = pair.first; 46 const string& after_node = pair.second; 47 bool seen_before = false; 48 bool seen_both = false; 49 for (const Node* node : inputs) { 50 if (!seen_before && after_node == node->name()) { 51 *error = strings::StrCat("Saw ", after_node, " before ", before_node); 52 return false; 53 } 54 55 if (before_node == node->name()) { 56 seen_before = true; 57 } else if (after_node == node->name()) { 58 seen_both = seen_before; 59 break; 60 } 61 } 62 if (!seen_both) { 63 *error = strings::StrCat("didn't see either ", before_node, " or ", 64 after_node); 65 return false; 66 } 67 } 68 69 return true; 70} 71 72TEST(AlgorithmTest, ReversePostOrder) { 73 GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 74 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 75 Node* w1 = SourceOp("TestParams", b.opts().WithName("W1")); 76 Node* w2 = SourceOp("TestParams", b.opts().WithName("W2")); 77 Node* input = 78 SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1)); 79 Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1")); 80 BinaryOp("TestMul", w1, {input, 1}, 81 b.opts().WithName("t2").WithControlInput(t1)); 82 BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); 83 84 Graph g(OpRegistry::Global()); 85 TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); 86 std::vector<Node*> order; 87 88 // Test reverse post order: 89 GetReversePostOrder(g, &order); 90 91 // Check that the order respects the dependencies correctly. 92 std::vector<std::pair<string, string>> reverse_orders = { 93 {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"}, 94 {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}}; 95 string error; 96 EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error; 97 98 // A false ordering should fail the check. 99 reverse_orders = {{"input", "W1"}}; 100 EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error)); 101 102 // Test post order: 103 GetPostOrder(g, &order); 104 105 // Check that the order respects the dependencies correctly. 106 std::vector<std::pair<string, string>> orders = { 107 {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"}, 108 {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}}; 109 EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error; 110 111 // A false ordering should fail the check. 112 orders = {{"W1", "t3"}}; 113 EXPECT_FALSE(ExpectBefore(orders, order, &error)); 114} 115 116TEST(AlgorithmTest, ReversePostOrderStable) { 117 int64 run_count = 100; 118 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 119 120 for (int64 i = 0; i < run_count; ++i) { 121 // One source of nondeterminism comes from unordered set with key of a 122 // pointer type, for example the order of FlatSet<Node*> depends on the 123 // raw pointer value of Node. Stable post order suppose to remove this 124 // nondeterminism by enforcing an ordering based on node ids. 125 GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 126 string error; 127 Node* w1 = SourceOp("TestParams", b.opts().WithName("W1")); 128 Node* input = 129 SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1)); 130 BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2")); 131 // Insert different number of nodes between the allocation of t2 and t3, 132 // this creates enough entropy in the memory distance between t2 and t3 thus 133 // forces them to have randomized ordering had stable DFS was not 134 // implemented correctly. 135 for (int64 j = 0; j < i; ++j) { 136 BinaryOp("TestMul", w1, {input, 1}, 137 b.opts().WithName(strings::StrCat("internal", j))); 138 } 139 140 BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3")); 141 142 Graph g(OpRegistry::Global()); 143 TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); 144 std::vector<Node*> order; 145 146 // Test reverse post order generates expected ordering. 147 GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorID()); 148 EXPECT_TRUE(ExpectBefore({{"t3", "t2"}}, order, &error)); 149 } 150} 151} // namespace 152} // namespace tensorflow 153