1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/ops/const_op.h"
17#include "tensorflow/cc/ops/image_ops.h"
18#include "tensorflow/cc/ops/nn_ops.h"
19#include "tensorflow/cc/ops/sendrecv_ops.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/platform/test.h"
24#include "tensorflow/core/platform/test_benchmark.h"
25#include "tensorflow/core/public/session.h"
26#include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28namespace tensorflow {
29namespace graph_transforms {
30
31class SortByExecutionOrderTest : public ::testing::Test {
32 protected:
33  void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) {
34    for (int i = 0; i < graph_def.node_size(); ++i) {
35      const NodeDef& node = graph_def.node(i);
36      (*order)[node.name()] = i;
37    }
38  }
39
40  void TestSimpleAdd() {
41    GraphDef graph_def;
42    NodeDef* add_node = graph_def.add_node();
43    add_node->set_name("add_node");
44    add_node->set_op("Add");
45    add_node->add_input("a_node");
46    add_node->add_input("b_node");
47
48    NodeDef* b_node = graph_def.add_node();
49    b_node->set_name("b_node");
50    b_node->set_op("Const");
51
52    NodeDef* a_node = graph_def.add_node();
53    a_node->set_name("a_node");
54    a_node->set_op("Const");
55
56    GraphDef result;
57    TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
58
59    std::map<string, int> order;
60    GetOrder(result, &order);
61    EXPECT_EQ(2, order["add_node"]);
62    EXPECT_GT(2, order["a_node"]);
63    EXPECT_GT(2, order["b_node"]);
64  }
65
66  void TestSimpleLinear() {
67    GraphDef graph_def;
68
69    NodeDef* negative_node = graph_def.add_node();
70    negative_node->set_name("negative_node");
71    negative_node->set_op("Negative");
72    negative_node->add_input("sqrt_node");
73
74    NodeDef* relu_node = graph_def.add_node();
75    relu_node->set_name("relu_node");
76    relu_node->set_op("Relu");
77    relu_node->add_input("const_node");
78
79    NodeDef* sqrt_node = graph_def.add_node();
80    sqrt_node->set_name("sqrt_node");
81    sqrt_node->set_op("Sqrt");
82    sqrt_node->add_input("relu_node");
83
84    NodeDef* const_node = graph_def.add_node();
85    const_node->set_name("const_node");
86    const_node->set_op("Const");
87
88    GraphDef result;
89    TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
90
91    std::map<string, int> order;
92    GetOrder(result, &order);
93    EXPECT_EQ(3, order["negative_node"]);
94    EXPECT_EQ(2, order["sqrt_node"]);
95    EXPECT_EQ(1, order["relu_node"]);
96    EXPECT_EQ(0, order["const_node"]);
97  }
98
99  void TestSimpleTree() {
100    GraphDef graph_def;
101
102    NodeDef* add_node1 = graph_def.add_node();
103    add_node1->set_name("add_node1");
104    add_node1->set_op("Add");
105    add_node1->add_input("add_node2");
106    add_node1->add_input("add_node3");
107
108    NodeDef* add_node2 = graph_def.add_node();
109    add_node2->set_name("add_node2");
110    add_node2->set_op("Add");
111    add_node2->add_input("const_node1");
112    add_node2->add_input("const_node2");
113
114    NodeDef* add_node3 = graph_def.add_node();
115    add_node3->set_name("add_node3");
116    add_node3->set_op("Add");
117    add_node3->add_input("const_node3");
118    add_node3->add_input("const_node4");
119
120    NodeDef* const_node1 = graph_def.add_node();
121    const_node1->set_name("const_node1");
122    const_node1->set_op("Const");
123
124    NodeDef* const_node2 = graph_def.add_node();
125    const_node2->set_name("const_node2");
126    const_node2->set_op("Const");
127
128    NodeDef* const_node3 = graph_def.add_node();
129    const_node3->set_name("const_node3");
130    const_node3->set_op("Const");
131
132    NodeDef* const_node4 = graph_def.add_node();
133    const_node4->set_name("const_node4");
134    const_node4->set_op("Const");
135
136    GraphDef result;
137    TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
138
139    std::map<string, int> order;
140    GetOrder(result, &order);
141    EXPECT_EQ(6, order["add_node1"]);
142    EXPECT_GT(6, order["add_node2"]);
143    EXPECT_GT(6, order["add_node3"]);
144    EXPECT_GT(5, order["const_node1"]);
145    EXPECT_GT(5, order["const_node2"]);
146    EXPECT_GT(5, order["const_node3"]);
147    EXPECT_GT(5, order["const_node4"]);
148  }
149
150  void TestCommonAncestor() {
151    GraphDef graph_def;
152
153    NodeDef* add_node1 = graph_def.add_node();
154    add_node1->set_name("add_node1");
155    add_node1->set_op("Add");
156    add_node1->add_input("add_node2");
157    add_node1->add_input("add_node3");
158
159    NodeDef* add_node2 = graph_def.add_node();
160    add_node2->set_name("add_node2");
161    add_node2->set_op("Add");
162    add_node2->add_input("const_node1");
163    add_node2->add_input("const_node2");
164
165    NodeDef* add_node3 = graph_def.add_node();
166    add_node3->set_name("add_node3");
167    add_node3->set_op("Add");
168    add_node3->add_input("const_node1");
169    add_node3->add_input("const_node3");
170
171    NodeDef* const_node1 = graph_def.add_node();
172    const_node1->set_name("const_node1");
173    const_node1->set_op("Const");
174
175    NodeDef* const_node2 = graph_def.add_node();
176    const_node2->set_name("const_node2");
177    const_node2->set_op("Const");
178
179    NodeDef* const_node3 = graph_def.add_node();
180    const_node3->set_name("const_node3");
181    const_node3->set_op("Const");
182
183    GraphDef result;
184    TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
185
186    std::map<string, int> order;
187    GetOrder(result, &order);
188    EXPECT_EQ(5, order["add_node1"]);
189    EXPECT_GT(5, order["add_node2"]);
190    EXPECT_GT(5, order["add_node3"]);
191    EXPECT_GT(4, order["const_node2"]);
192    EXPECT_GT(4, order["const_node3"]);
193    EXPECT_GT(3, order["const_node1"]);
194  }
195};
196
197TEST_F(SortByExecutionOrderTest, TestSimpleAdd) { TestSimpleAdd(); }
198
199TEST_F(SortByExecutionOrderTest, TestSimpleLinear) { TestSimpleLinear(); }
200
201TEST_F(SortByExecutionOrderTest, TestSimpleTree) { TestSimpleTree(); }
202
203TEST_F(SortByExecutionOrderTest, TestCommonAncestor) { TestCommonAncestor(); }
204
205}  // namespace graph_transforms
206}  // namespace tensorflow
207