1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/contrib/tensorrt/segment/segment.h"

#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
24
25namespace tensorflow {
26namespace tensorrt {
27namespace segment {
28namespace test {
29
30class SegmentTest : public ::testing::Test {
31 public:
32  bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
33
34  TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
35  TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
36                    TF_Status* s, const char* name);
37
38  std::function<bool(const NodeDef&)> MakeCandidateFn(
39      const std::set<string>& node_names);
40
41 protected:
42  void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
43                         TF_Operation** op);
44  void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
45                 TF_Status* s, const char* name, TF_Operation** op, bool check);
46
47  SegmentOptions default_options_;
48};
49
50bool SegmentTest::GetGraphDef(TF_Graph* graph,
51                              tensorflow::GraphDef* graph_def) {
52  TF_Status* s = TF_NewStatus();
53  TF_Buffer* buffer = TF_NewBuffer();
54  TF_GraphToGraphDef(graph, buffer, s);
55  bool ret = TF_GetCode(s) == TF_OK;
56  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
57  if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
58  TF_DeleteBuffer(buffer);
59  TF_DeleteStatus(s);
60  return ret;
61}
62
63std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
64    const std::set<string>& node_names) {
65  return [node_names](const NodeDef& node) -> bool {
66    return node_names.find(node.name()) != node_names.end();
67  };
68}
69
70void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
71                                    const char* name, TF_Operation** op) {
72  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
73  TF_SetAttrType(desc, "dtype", TF_INT32);
74  *op = TF_FinishOperation(desc, s);
75  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
76  ASSERT_NE(*op, nullptr);
77}
78
79TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
80                                       const char* name) {
81  TF_Operation* op;
82  PlaceholderHelper(graph, s, name, &op);
83  return op;
84}
85
86void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
87                            TF_Status* s, const char* name, TF_Operation** op,
88                            bool check) {
89  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
90  TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
91  TF_AddInputList(desc, add_inputs, 2);
92  *op = TF_FinishOperation(desc, s);
93  if (check) {
94    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
95    ASSERT_NE(*op, nullptr);
96  }
97}
98
99TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
100                               TF_Graph* graph, TF_Status* s,
101                               const char* name) {
102  TF_Operation* op;
103  AddHelper(l, r, graph, s, name, &op, true);
104  return op;
105}
106
107TEST_F(SegmentTest, Empty) {
108  TF_Graph* graph = TF_NewGraph();
109
110  GraphDef graph_def;
111  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
112
113  SegmentNodesVector segments;
114  ASSERT_EQ(
115      SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
116      tensorflow::Status::OK());
117
118  // Expect no segments/subgraphs.
119  EXPECT_TRUE(segments.empty());
120  TF_DeleteGraph(graph);
121}
122
123TEST_F(SegmentTest, Simple) {
124  TF_Status* s = TF_NewStatus();
125  TF_Graph* graph = TF_NewGraph();
126
127  //           feed
128  //         //    ||
129  //       add0    add1
130  //        | |    /
131  //        |  add2
132  //        |  /  ||
133  //       add3    add4
134  //           |  /
135  //          <sink>
136  //
137  TF_Operation* feed = Placeholder(graph, s, "feed");
138  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
139  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
140
141  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
142  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
143  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
144  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
145  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
146  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
147  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
148  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
149  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
150  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
151  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
152  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
153
154  GraphDef graph_def;
155  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
156
157  SegmentNodesVector segments;
158  ASSERT_EQ(
159      SegmentGraph(graph_def,
160                   MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
161                   default_options_, &segments),
162      tensorflow::Status::OK());
163
164  // Expect all Add operations to be collapsed into a single segment
165  ASSERT_EQ(segments.size(), 1);
166  std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
167  for (const auto& ex : expected) {
168    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
169        << "Missing expected node " << ex;
170  }
171  TF_DeleteGraph(graph);
172  TF_DeleteStatus(s);
173}
174
175TEST_F(SegmentTest, AvoidCycle) {
176  TF_Status* s = TF_NewStatus();
177  TF_Graph* graph = TF_NewGraph();
178
179  // add2 is not a TRT candidate so add0/add3 cannot be formed as a
180  // subgraph
181  //
182  //           feed
183  //         //    ||
184  //       add0    add1
185  //        | |    /
186  //        |  add2
187  //        |  /  ||
188  //       add3    add4
189  //           |  /
190  //          <sink>
191  //
192  TF_Operation* feed = Placeholder(graph, s, "feed");
193  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
194  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
195
196  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
197  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
198  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
199  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
200  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
201  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
202  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
203  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
204  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
205  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
206  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
207  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
208
209  GraphDef graph_def;
210  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
211
212  SegmentNodesVector segments;
213  ASSERT_EQ(
214      SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
215                   default_options_, &segments),
216      tensorflow::Status::OK());
217
218  // Expect no subgraphs
219  EXPECT_EQ(segments.size(), 0);
220  TF_DeleteGraph(graph);
221  TF_DeleteStatus(s);
222}
223
224TEST_F(SegmentTest, Multiple) {
225  TF_Status* s = TF_NewStatus();
226  TF_Graph* graph = TF_NewGraph();
227
228  // add5 is not a TRT candidate so two subgraphs should be formed
229  //
230  //                feed
231  //         //      ||     ||
232  //       add0    add1      add7
233  //        | |    /        /   ||
234  //        |  add2-----add5    add8
235  //        |  /  |    |  |    |
236  //       add3   add4     add6
237  //           |     |     /
238  //               <sink>
239  //
240  TF_Operation* feed = Placeholder(graph, s, "feed");
241  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
242  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
243
244  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
245  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
246  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
247  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
248  TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
249  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
250  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
251  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
252  TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
253  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
254  TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
255  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
256  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
257  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
258  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
259  TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
260  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
261  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
262  TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
263  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
264  EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
265
266  GraphDef graph_def;
267  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
268
269  SegmentNodesVector segments;
270  ASSERT_EQ(SegmentGraph(graph_def,
271                         MakeCandidateFn({"add0", "add1", "add2", "add3",
272                                          "add4", "add6", "add7", "add8"}),
273                         default_options_, &segments),
274            tensorflow::Status::OK());
275
276  // Expect two subgraphs
277  EXPECT_EQ(segments.size(), 2);
278
279  std::vector<string> expected0{"add0", "add1", "add2", "add3"};
280  for (const auto& ex : expected0) {
281    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
282        << "Missing expected node " << ex;
283  }
284
285  std::vector<string> expected1{"add6", "add8"};
286  for (const auto& ex : expected1) {
287    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
288        << "Missing expected node " << ex;
289  }
290  TF_DeleteGraph(graph);
291  TF_DeleteStatus(s);
292}
293
294TEST_F(SegmentTest, BigIfElse) {
295  TF_Status* s = TF_NewStatus();
296  TF_Graph* graph = TF_NewGraph();
297
298  // add2 is not a TRT candidate
299  //
300  //           feed
301  //            ||
302  //           add0
303  //         //    ||
304  //       add1    add4
305  //        ||      ||
306  //       add2    add5
307  //        ||      ||
308  //       add3    add6
309  //         ||    //
310  //           add7
311  //            ||
312  //          <sink>
313  //
314  TF_Operation* feed = Placeholder(graph, s, "feed");
315  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
316  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
317
318  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
319  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
320  TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
321  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
322  TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
323  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
324  TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
325  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
326  TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
327  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
328  TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
329  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
330  TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
331  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
332  TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
333  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
334  EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
335
336  GraphDef graph_def;
337  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
338
339  SegmentNodesVector segments;
340  ASSERT_EQ(SegmentGraph(graph_def,
341                         MakeCandidateFn({"add0", "add1", "add3", "add4",
342                                          "add5", "add6", "add7"}),
343                         default_options_, &segments),
344            tensorflow::Status::OK());
345
346  // Expect 2 subgraphs
347  EXPECT_EQ(segments.size(), 2);
348
349  std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
350  for (const auto& ex : expected0) {
351    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
352        << "Missing expected node " << ex;
353  }
354
355  std::vector<string> expected1{"add0", "add1"};
356  for (const auto& ex : expected1) {
357    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
358        << "Missing expected node " << ex;
359  }
360  TF_DeleteGraph(graph);
361  TF_DeleteStatus(s);
362}
363
364}  // namespace test
365}  // namespace segment
366}  // namespace tensorrt
367}  // namespace tensorflow
368