1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/contrib/tensorrt/segment/segment.h" 17#include "tensorflow/c/c_api.h" 18#include "tensorflow/core/framework/graph.pb.h" 19#include "tensorflow/core/framework/node_def.pb.h" 20#include "tensorflow/core/lib/core/errors.h" 21#include "tensorflow/core/lib/core/status.h" 22#include "tensorflow/core/platform/test.h" 23#include "tensorflow/core/platform/types.h" 24 25namespace tensorflow { 26namespace tensorrt { 27namespace segment { 28namespace test { 29 30class SegmentTest : public ::testing::Test { 31 public: 32 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); 33 34 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name); 35 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 36 TF_Status* s, const char* name); 37 38 std::function<bool(const NodeDef&)> MakeCandidateFn( 39 const std::set<string>& node_names); 40 41 protected: 42 void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, 43 TF_Operation** op); 44 void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 45 TF_Status* s, const char* name, TF_Operation** op, bool check); 46 47 SegmentOptions default_options_; 48}; 49 50bool SegmentTest::GetGraphDef(TF_Graph* graph, 51 tensorflow::GraphDef* graph_def) { 52 TF_Status* s = TF_NewStatus(); 53 TF_Buffer* buffer = TF_NewBuffer(); 54 TF_GraphToGraphDef(graph, buffer, s); 55 bool ret = TF_GetCode(s) == TF_OK; 56 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 57 if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); 58 TF_DeleteBuffer(buffer); 59 TF_DeleteStatus(s); 60 return ret; 61} 62 63std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn( 64 const std::set<string>& node_names) { 65 return [node_names](const NodeDef& node) -> bool { 66 return node_names.find(node.name()) != node_names.end(); 67 }; 68} 69 70void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s, 71 const char* name, TF_Operation** op) { 72 TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); 73 TF_SetAttrType(desc, "dtype", TF_INT32); 74 *op = TF_FinishOperation(desc, s); 75 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 76 ASSERT_NE(*op, nullptr); 77} 78 79TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s, 80 const char* name) { 81 TF_Operation* op; 82 PlaceholderHelper(graph, s, name, &op); 83 return op; 84} 85 86void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 87 TF_Status* s, const char* name, TF_Operation** op, 88 bool check) { 89 TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); 90 TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; 91 TF_AddInputList(desc, add_inputs, 2); 92 *op = TF_FinishOperation(desc, s); 93 if (check) { 94 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 95 ASSERT_NE(*op, nullptr); 96 } 97} 98 99TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r, 100 TF_Graph* graph, TF_Status* s, 101 const char* name) { 102 TF_Operation* op; 103 AddHelper(l, r, graph, s, name, &op, true); 104 return op; 105} 106 107TEST_F(SegmentTest, Empty) { 108 TF_Graph* graph = TF_NewGraph(); 109 110 GraphDef graph_def; 111 ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 112 113 SegmentNodesVector segments; 114 ASSERT_EQ( 115 SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments), 116 tensorflow::Status::OK()); 117 118 // Expect no segments/subgraphs. 119 EXPECT_TRUE(segments.empty()); 120 TF_DeleteGraph(graph); 121} 122 123TEST_F(SegmentTest, Simple) { 124 TF_Status* s = TF_NewStatus(); 125 TF_Graph* graph = TF_NewGraph(); 126 127 // feed 128 // // || 129 // add0 add1 130 // | | / 131 // | add2 132 // | / || 133 // add3 add4 134 // | / 135 // <sink> 136 // 137 TF_Operation* feed = Placeholder(graph, s, "feed"); 138 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 139 EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 140 141 TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); 142 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 143 TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); 144 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 145 TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); 146 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 147 TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); 148 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 149 EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); 150 TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); 151 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 152 EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); 153 154 GraphDef graph_def; 155 ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 156 157 SegmentNodesVector segments; 158 ASSERT_EQ( 159 SegmentGraph(graph_def, 160 MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}), 161 default_options_, &segments), 162 tensorflow::Status::OK()); 163 164 // Expect all Add operations to be collapsed into a single segment 165 ASSERT_EQ(segments.size(), 1); 166 std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"}; 167 for (const auto& ex : expected) { 168 EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) 169 << "Missing expected node " << ex; 170 } 171 TF_DeleteGraph(graph); 172 TF_DeleteStatus(s); 173} 174 175TEST_F(SegmentTest, AvoidCycle) { 176 TF_Status* s = TF_NewStatus(); 177 TF_Graph* graph = TF_NewGraph(); 178 179 // add2 is not a TRT candidate so add0/add3 cannot be formed as a 180 // subgraph 181 // 182 // feed 183 // // || 184 // add0 add1 185 // | | / 186 // | add2 187 // | / || 188 // add3 add4 189 // | / 190 // <sink> 191 // 192 TF_Operation* feed = Placeholder(graph, s, "feed"); 193 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 194 EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 195 196 TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); 197 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 198 TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); 199 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 200 TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); 201 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 202 TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); 203 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 204 EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); 205 TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); 206 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 207 EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); 208 209 GraphDef graph_def; 210 ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 211 212 SegmentNodesVector segments; 213 ASSERT_EQ( 214 SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}), 215 default_options_, &segments), 216 tensorflow::Status::OK()); 217 218 // Expect no subgraphs 219 EXPECT_EQ(segments.size(), 0); 220 TF_DeleteGraph(graph); 221 TF_DeleteStatus(s); 222} 223 224TEST_F(SegmentTest, Multiple) { 225 TF_Status* s = TF_NewStatus(); 226 TF_Graph* graph = TF_NewGraph(); 227 228 // add5 is not a TRT candidate so two subgraphs should be formed 229 // 230 // feed 231 // // || || 232 // add0 add1 add7 233 // | | / / || 234 // | add2-----add5 add8 235 // | / | | | | 236 // add3 add4 add6 237 // | | / 238 // <sink> 239 // 240 TF_Operation* feed = Placeholder(graph, s, "feed"); 241 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 242 EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 243 244 TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); 245 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 246 TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); 247 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 248 TF_Operation* add7 = Add(feed, feed, graph, s, "add7"); 249 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 250 TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); 251 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 252 TF_Operation* add5 = Add(add2, add7, graph, s, "add5"); 253 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 254 TF_Operation* add8 = Add(add7, add7, graph, s, "add8"); 255 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 256 TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); 257 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 258 EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); 259 TF_Operation* add4 = Add(add2, add5, graph, s, "add4"); 260 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 261 EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); 262 TF_Operation* add6 = Add(add5, add8, graph, s, "add6"); 263 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 264 EXPECT_EQ(string("add6"), string(TF_OperationName(add6))); 265 266 GraphDef graph_def; 267 ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 268 269 SegmentNodesVector segments; 270 ASSERT_EQ(SegmentGraph(graph_def, 271 MakeCandidateFn({"add0", "add1", "add2", "add3", 272 "add4", "add6", "add7", "add8"}), 273 default_options_, &segments), 274 tensorflow::Status::OK()); 275 276 // Expect two subgraphs 277 EXPECT_EQ(segments.size(), 2); 278 279 std::vector<string> expected0{"add0", "add1", "add2", "add3"}; 280 for (const auto& ex : expected0) { 281 EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) 282 << "Missing expected node " << ex; 283 } 284 285 std::vector<string> expected1{"add6", "add8"}; 286 for (const auto& ex : expected1) { 287 EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) 288 << "Missing expected node " << ex; 289 } 290 TF_DeleteGraph(graph); 291 TF_DeleteStatus(s); 292} 293 294TEST_F(SegmentTest, BigIfElse) { 295 TF_Status* s = TF_NewStatus(); 296 TF_Graph* graph = TF_NewGraph(); 297 298 // add2 is not a TRT candidate 299 // 300 // feed 301 // || 302 // add0 303 // // || 304 // add1 add4 305 // || || 306 // add2 add5 307 // || || 308 // add3 add6 309 // || // 310 // add7 311 // || 312 // <sink> 313 // 314 TF_Operation* feed = Placeholder(graph, s, "feed"); 315 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 316 EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 317 318 TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); 319 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 320 TF_Operation* add1 = Add(add0, add0, graph, s, "add1"); 321 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 322 TF_Operation* add2 = Add(add1, add1, graph, s, "add2"); 323 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 324 TF_Operation* add3 = Add(add2, add2, graph, s, "add3"); 325 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 326 TF_Operation* add4 = Add(add0, add0, graph, s, "add4"); 327 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 328 TF_Operation* add5 = Add(add4, add4, graph, s, "add5"); 329 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 330 TF_Operation* add6 = Add(add5, add5, graph, s, "add6"); 331 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 332 TF_Operation* add7 = Add(add3, add6, graph, s, "add7"); 333 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 334 EXPECT_EQ(string("add7"), string(TF_OperationName(add7))); 335 336 GraphDef graph_def; 337 ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 338 339 SegmentNodesVector segments; 340 ASSERT_EQ(SegmentGraph(graph_def, 341 MakeCandidateFn({"add0", "add1", "add3", "add4", 342 "add5", "add6", "add7"}), 343 default_options_, &segments), 344 tensorflow::Status::OK()); 345 346 // Expect 2 subgraphs 347 EXPECT_EQ(segments.size(), 2); 348 349 std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"}; 350 for (const auto& ex : expected0) { 351 EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) 352 << "Missing expected node " << ex; 353 } 354 355 std::vector<string> expected1{"add0", "add1"}; 356 for (const auto& ex : expected1) { 357 EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) 358 << "Missing expected node " << ex; 359 } 360 TF_DeleteGraph(graph); 361 TF_DeleteStatus(s); 362} 363 364} // namespace test 365} // namespace segment 366} // namespace tensorrt 367} // namespace tensorflow 368