// c_api_function_test.cc revision 9624d165f1f2c717eda96464fee8bf7229cc14f5
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/c/c_api.h" 17 18#include "tensorflow/c/c_test_util.h" 19#include "tensorflow/core/framework/function.pb.h" 20#include "tensorflow/core/framework/op_def.pb.h" 21#include "tensorflow/core/lib/strings/str_util.h" 22#include "tensorflow/core/lib/strings/strcat.h" 23#include "tensorflow/core/platform/logging.h" 24#include "tensorflow/core/platform/test.h" 25 26namespace tensorflow { 27namespace { 28 29// Specification for expected input/output and its type. 30// DataType value of DT_INVALID signifies that we don't want to 31// check the data type. 32typedef std::pair<string, DataType> IOSpec; 33 34std::vector<IOSpec> M(const std::initializer_list<string>& names) { 35 std::vector<IOSpec> v; 36 for (const string& name : names) { 37 v.push_back(IOSpec(name, DT_INVALID)); 38 } 39 return v; 40} 41 42// Specification for an expected edge. 
43// src is either: 44// - input name (as it appears in FunctionDef) 45// - name of output tensor (in nested "add:z:0" format) 46// dst is either: 47// - output name (as it appears in FunctionDef) 48// - <name_of_node>:<index_of_this_input_into_node> (this looks the same as 49// output tensor naming, but it the index is actually an input index) 50struct EdgeSpec : public std::pair<string, string> { 51 typedef std::pair<string, string> Base; 52 53 // Inherit the set of constructors 54 using Base::pair; 55 56 string ToString() const { return strings::StrCat(first, "->", second); } 57}; 58 59class CApiFunctionTest : public ::testing::Test { 60 protected: 61 CApiFunctionTest() 62 : s_(TF_NewStatus()), 63 func_graph_(TF_NewGraph()), 64 host_graph_(TF_NewGraph()), 65 func_(nullptr) {} 66 67 void SetUp() override {} 68 69 ~CApiFunctionTest() override { 70 TF_DeleteFunction(func_); 71 TF_DeleteGraph(host_graph_); 72 TF_DeleteGraph(func_graph_); 73 TF_DeleteStatus(s_); 74 } 75 76 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, 77 TF_Operation* output, int32_t expected_result) { 78 Run(inputs, {{output, 0}}, {expected_result}); 79 } 80 81 // Run the host graph, which now contains a function and check that 82 // outputs are as expected. 83 // 'T' stands for 'tensor' since the outputs are tensors, not scalars. 
84 void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, 85 std::initializer_list<TF_Output> outputs, 86 const std::vector<std::vector<int32_t>>& expected_results) { 87 // Create a session for this graph 88 CSession csession(host_graph_, s_); 89 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 90 91 // Run 92 csession.SetInputs(inputs); 93 csession.SetOutputs(outputs); 94 csession.Run(s_); 95 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 96 97 // Check results 98 for (int i = 0; i < expected_results.size(); ++i) { 99 TF_Tensor* out = csession.output_tensor(i); 100 ASSERT_TRUE(out != nullptr); 101 EXPECT_EQ(TF_INT32, TF_TensorType(out)); 102 EXPECT_EQ(1, TF_NumDims(out)); 103 CompareInt32Tensor(expected_results[i], out); 104 } 105 } 106 107 // Run the host graph, which now contains a function and check that 108 // outputs are as expected. 109 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, 110 std::initializer_list<TF_Output> outputs, 111 const std::vector<int32_t>& expected_results) { 112 // Create a session for this graph. 
113 CSession csession(host_graph_, s_); 114 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 115 116 csession.SetInputs(inputs); 117 csession.SetOutputs(outputs); 118 csession.Run(s_); 119 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 120 121 for (int i = 0; i < expected_results.size(); ++i) { 122 TF_Tensor* out = csession.output_tensor(i); 123 ASSERT_TRUE(out != nullptr); 124 EXPECT_EQ(TF_INT32, TF_TensorType(out)); 125 EXPECT_EQ(0, TF_NumDims(out)); // scalar 126 ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out)); 127 int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out)); 128 EXPECT_EQ(expected_results[i], *output_contents); 129 } 130 } 131 132 void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) { 133 int32_t* data = static_cast<int32_t*>(TF_TensorData(t)); 134 size_t size = TF_TensorByteSize(t); 135 ASSERT_EQ(expected.size() * sizeof(int32_t), size); 136 for (int i = 0; i < expected.size(); ++i) { 137 ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i; 138 } 139 } 140 141 std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) { 142 std::vector<TF_Output> out; 143 for (auto op : ops) { 144 out.push_back({op, 0}); 145 } 146 return out; 147 } 148 149 void Define(int num_opers, const std::vector<TF_Operation*>& opers, 150 const std::vector<TF_Operation*>& inputs, 151 const std::vector<TF_Operation*>& outputs, 152 const char** output_names, bool expect_failure = false) { 153 DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names, 154 expect_failure); 155 } 156 157 // An explicit `num_opers` is needed so that we can distinguish between the 158 // case of no operations specified (-1) and the case of an empty set of 159 // operations specified (0). 
160 void DefineT(int num_opers, const std::vector<TF_Operation*>& opers, 161 const std::vector<TF_Output>& inputs, 162 const std::vector<TF_Output>& outputs, const char** output_names, 163 bool expect_failure = false) { 164 ASSERT_EQ(func_, nullptr); 165 func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers, 166 num_opers == -1 ? nullptr : opers.data(), 167 inputs.size(), inputs.data(), outputs.size(), 168 outputs.data(), output_names, 169 /*opts=*/nullptr, s_); 170 if (expect_failure) { 171 ASSERT_EQ(func_, nullptr); 172 return; 173 } 174 175 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 176 ASSERT_NE(func_, nullptr); 177 TF_GraphAddFunction(host_graph_, func_, s_); 178 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 179 } 180 181 TF_Operation* Use(const std::vector<TF_Operation*>& inputs) { 182 return UseT(ToOutput(inputs)); 183 } 184 185 TF_Operation* UseT(const std::vector<TF_Output>& inputs) { 186 TF_Operation* op; 187 UseHelper(inputs, &op); 188 return op; 189 } 190 191 // All the *Helper methods are used as a workaround for the restrictions that 192 // one cannot call ASSERT_* methods in non-void-returning functions (when 193 // exceptions are disabled during compilation) 194 void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) { 195 TF_OperationDescription* desc = 196 TF_NewOperation(host_graph_, func_name_, func_node_name_); 197 for (auto input : inputs) { 198 TF_AddInput(desc, input); 199 } 200 // Set device to CPU because some ops inside the function might not be 201 // available on GPU. 
202 TF_SetDevice(desc, "/cpu:0"); 203 *op = TF_FinishOperation(desc, s_); 204 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 205 ASSERT_NE(*op, nullptr); 206 } 207 208 FunctionDef fdef() { 209 tensorflow::FunctionDef fdef; 210 EXPECT_TRUE(GetFunctionDef(func_, &fdef)); 211 return fdef; 212 } 213 214 // logging utility 215 template <class Container> 216 string ToString(const Container& v) { 217 std::stringstream ss; 218 ss << "{"; 219 size_t i = 0; 220 for (const auto& e : v) { 221 if (i != 0) { 222 ss << ", "; 223 } 224 ss << e.ToString(); 225 ++i; 226 } 227 ss << "}"; 228 return ss.str(); 229 } 230 231 void VerifyFDefNodes(const tensorflow::FunctionDef& fdef, 232 const std::unordered_set<string>& nodes) { 233 ASSERT_EQ(nodes.size(), fdef.node_def_size()) 234 << "Got unexpected number of nodes. Expected: [" 235 << str_util::Join(nodes, ", ") 236 << "] Actual nodes in fdef: " << fdef.DebugString(); 237 for (const NodeDef& node_def : fdef.node_def()) { 238 ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end()) 239 << "Got unexpected node: " << node_def.name() 240 << " in fdef: " << fdef.DebugString(); 241 } 242 } 243 244 void VerifyFDefInputs(const tensorflow::FunctionDef& fdef, 245 const std::vector<IOSpec>& inputs) { 246 const OpDef& signature = fdef.signature(); 247 ASSERT_EQ(inputs.size(), signature.input_arg_size()); 248 for (int i = 0; i < inputs.size(); ++i) { 249 const OpDef::ArgDef& arg = signature.input_arg(i); 250 const IOSpec& in = inputs[i]; 251 if (in.second != DT_INVALID) { 252 ASSERT_EQ(arg.type(), in.second) 253 << "Got unexpected type for input " << i 254 << ". fdef: " << fdef.DebugString(); 255 } 256 ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i 257 << ". 
fdef: " << fdef.DebugString(); 258 } 259 } 260 261 void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef, 262 const std::vector<IOSpec>& outputs) { 263 const OpDef& signature = fdef.signature(); 264 ASSERT_EQ(outputs.size(), signature.output_arg_size()); 265 for (int i = 0; i < outputs.size(); ++i) { 266 const OpDef::ArgDef& arg = signature.output_arg(i); 267 const IOSpec& out = outputs[i]; 268 if (out.second != DT_INVALID) { 269 ASSERT_EQ(arg.type(), out.second) 270 << "Got unexpected type for output " << i 271 << ". fdef: " << fdef.DebugString(); 272 } 273 ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i 274 << ". fdef: " << fdef.DebugString(); 275 } 276 } 277 278 void VerifyFDefEdges( 279 const tensorflow::FunctionDef& fdef, 280 const std::vector<EdgeSpec>& e_edges, // expected edges 281 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges 282 bool is_exact_edges = true) { 283 // Build a set of edges from fdef 284 std::set<EdgeSpec> a_edges; // actual edges 285 // Get edges from inputs to body nodes and between body nodes 286 for (const NodeDef& node_def : fdef.node_def()) { 287 for (int i = 0; i < node_def.input_size(); ++i) { 288 const string& in = node_def.input(i); 289 const auto& v = 290 a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)}); 291 ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> " 292 << strings::StrCat(node_def.name(), ":", i) 293 << ". fdef: " << fdef.DebugString(); 294 } 295 } 296 // Get edges from body nodes to outputs and from inputs to outputs 297 for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) { 298 const auto& iter = fdef.ret().find(arg.name()); 299 if (iter != fdef.ret().end()) { 300 const auto& v = a_edges.insert({iter->second, arg.name()}); 301 ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> " 302 << arg.name() << ". 
fdef: " << fdef.DebugString(); 303 } else { 304 const auto& v = a_edges.insert({arg.name(), arg.name()}); 305 ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> " 306 << arg.name() << ". fdef: " << fdef.DebugString(); 307 } 308 } 309 310 // Verify edges 311 for (const EdgeSpec& e : e_edges) { 312 ASSERT_TRUE(a_edges.find(e) != a_edges.end()) 313 << "Failed to find expected edge " << e.ToString() 314 << " in fdef: " << fdef.DebugString(); 315 } 316 317 // If caller specified all edges, check that we have seen all 318 if (is_exact_edges) { 319 ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size()) 320 << "Expected edges: " << ToString(e_edges) 321 << " Expected Control edges: " << ToString(c_edges) 322 << " Actual edges: " << ToString(a_edges) 323 << " in fdef: " << fdef.DebugString(); 324 } 325 } 326 327 void VerifyFDef(const std::unordered_set<string>& nodes, 328 const std::vector<IOSpec>& inputs, 329 const std::vector<IOSpec>& outputs, 330 const std::vector<EdgeSpec>& e_edges, // expected edges 331 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges 332 bool is_exact_edges = true) { 333 tensorflow::FunctionDef fdef; 334 ASSERT_TRUE(GetFunctionDef(func_, &fdef)); 335 VerifyFDefNodes(fdef, nodes); 336 VerifyFDefInputs(fdef, inputs); 337 VerifyFDefOutputs(fdef, outputs); 338 VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges); 339 } 340 341 const char* func_name_ = "MyFunc"; 342 const char* func_node_name_ = "MyFunc_0"; 343 TF_Status* s_; 344 TF_Graph* func_graph_; 345 TF_Graph* host_graph_; 346 TF_Function* func_; 347 348 // Workaround for not being able to initialize empty map using {} 349 std::unordered_set<string> empty_; 350}; 351 352TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) { 353 /* 354 * constant 355 * | 356 * v 357 */ 358 // Define 359 TF_Operation* c = ScalarConst(10, func_graph_, s_, "scalar10"); 360 Define(-1, {}, {}, {c}, nullptr); 361 362 // Use, run, and verify 363 TF_Operation* func_op = Use({}); 364 Run({}, 
func_op, 10); 365 VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}}, 366 {{"scalar10_0:output:0", "scalar10"}}, {}); 367} 368 369TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) { 370 /* 371 * | 372 * v 373 * negate 374 * | 375 * v 376 */ 377 // Define 378 TF_Operation* feed = Placeholder(func_graph_, s_); 379 TF_Operation* neg = Neg(feed, func_graph_, s_); 380 Define(-1, {}, {feed}, {neg}, nullptr); 381 382 // Use, run, and verify 383 TF_Operation* func_feed = Placeholder(host_graph_, s_); 384 TF_Operation* func_op = Use({func_feed}); 385 Run({{func_feed, Int32Tensor(3)}}, func_op, -3); 386 VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}}, 387 {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {}); 388} 389 390TEST_F(CApiFunctionTest, ZeroOps_Identity) { 391 /* 392 * | 393 * | 394 * | 395 * v 396 */ 397 // Define 398 TF_Operation* feed = Placeholder(func_graph_, s_); 399 Define(-1, {}, {feed}, {feed}, nullptr); 400 401 // Use, run, and verify 402 TF_Operation* func_feed = Placeholder(host_graph_, s_); 403 TF_Operation* func_op = Use({func_feed}); 404 Run({{func_feed, Int32Tensor(3)}}, func_op, 3); 405 VerifyFDef(empty_, {{"feed", DT_INT32}}, {{"feed_0", DT_INT32}}, 406 {{"feed", "feed_0"}}, {}); 407} 408 409TEST_F(CApiFunctionTest, ZeroOps_Permutation) { 410 /* 411 * | | 412 * \ / 413 * \/ 414 * x 415 * /\ 416 * / \ 417 * | | 418 * v v 419 */ 420 // Define 421 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 422 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 423 Define(-1, {}, {feed1, feed2}, {feed2, feed1}, nullptr); 424 425 // Use, run, and verify 426 TF_Operation* two = ScalarConst(2, host_graph_, s_); 427 TF_Operation* func_feed = Placeholder(host_graph_, s_); 428 TF_Operation* func_op = Use({two, func_feed}); 429 Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {3, 2}); 430 VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"feed2_0"}, {"feed1_0"}}), 431 {{"feed1", "feed1_0"}, {"feed2", "feed2_0"}}, 
{}); 432} 433 434TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) { 435 /* 436 * | | 437 * v v 438 * add 439 * | 440 * v 441 */ 442 // Define 443 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 444 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 445 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 446 Define(-1, {}, {feed1, feed2}, {add}, nullptr); 447 448 // Use, run, and verify 449 TF_Operation* two = ScalarConst(2, host_graph_, s_); 450 TF_Operation* func_feed = Placeholder(host_graph_, s_); 451 TF_Operation* func_op = Use({two, func_feed}); 452 Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); 453 VerifyFDef( 454 {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), 455 {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {}); 456} 457 458TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) { 459 /* 460 * | | 461 * v v 462 * add 463 * 464 * (output ignored) 465 */ 466 // Define 467 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 468 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 469 Add(feed1, feed2, func_graph_, s_); 470 Define(-1, {}, {feed1, feed2}, {}, nullptr); 471 472 // Use, run, and verify 473 TF_Operation* two = ScalarConst(2, host_graph_, s_); 474 TF_Operation* func_feed = Placeholder(host_graph_, s_); 475 Use({two, func_feed}); 476 VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, 477 {{"feed1", "add:0"}, {"feed2", "add:1"}}, {}); 478} 479 480TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) { 481 /* 482 * | | | 483 * v v / 484 * add1 / 485 * | | 486 * v v 487 * add2 488 * | 489 * v 490 */ 491 // Define 492 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 493 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 494 TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); 495 TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); 496 TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); 497 Define(-1, {}, {feed1, feed2, 
feed3}, {add2}, nullptr); 498 499 // Use, run, and verify 500 TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); 501 TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); 502 TF_Operation* func_feed = Placeholder(host_graph_, s_); 503 TF_Operation* func_op = Use({two, ten, func_feed}); 504 Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 10 + 3); 505 VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), 506 M({{"add2"}}), 507 {{"feed1", "add1:0"}, 508 {"feed2", "add1:1"}, 509 {"add1:sum:0", "add2_0:0"}, 510 {"feed3", "add2_0:1"}, 511 {"add2_0:sum:0", "add2"}}, 512 {}); 513} 514 515TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) { 516 /* 517 * | | 518 * v v 519 * add 520 * | 521 * +-+-+ 522 * | | 523 * v v 524 */ 525 // Define 526 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 527 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 528 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 529 Define(-1, {}, {feed1, feed2}, {add, add}, nullptr); 530 531 // Use, run, and verify 532 TF_Operation* two = ScalarConst(2, host_graph_, s_); 533 TF_Operation* func_feed = Placeholder(host_graph_, s_); 534 TF_Operation* func_op = Use({two, func_feed}); 535 Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {5, 5}); 536 VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}), 537 {{"feed1", "add_1:0"}, 538 {"feed2", "add_1:1"}, 539 {"add_1:sum:0", "add"}, 540 {"add_1:sum:0", "add_0"}}, 541 {}); 542} 543 544TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) { 545 /* 546 * | | | 547 * v v / 548 * add / 549 * | | 550 * +-+ | 551 * | | | 552 * | v v 553 * | add 554 * | | 555 * v v 556 */ 557 // Define 558 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 559 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 560 TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); 561 TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); 562 TF_Operation* 
add2 = Add(add1, feed3, func_graph_, s_, "add2"); 563 Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, nullptr); 564 565 // Use, run, and verify 566 TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); 567 TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); 568 TF_Operation* func_feed = Placeholder(host_graph_, s_); 569 TF_Operation* func_op = Use({two, ten, func_feed}); 570 Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15}); 571 VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), 572 M({{"add1"}, {"add2"}}), 573 {{"feed1", "add1_0:0"}, 574 {"feed2", "add1_0:1"}, 575 {"add1_0:sum:0", "add2_0:0"}, 576 {"feed3", "add2_0:1"}, 577 {"add1_0:sum:0", "add1"}, 578 {"add2_0:sum:0", "add2"}}, 579 {}); 580} 581 582TEST_F(CApiFunctionTest, FromSubsetOfOps) { 583 /* 584 * | | | 585 * v v / 586 * add / 587 * | | 588 * +---+--+---+ 589 * Ops used | | | | 590 * for func | v v | 591 * | | add | 592 * +-------> | | | 593 * | v | 594 * | | 595 * +----------+ 596 */ 597 // Define 598 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 599 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 600 TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); 601 TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); 602 TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); 603 Define(1, {add2}, {add1, feed3}, {add2}, nullptr); 604 605 // Use, run, and verify 606 TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); 607 TF_Operation* func_feed = Placeholder(host_graph_, s_); 608 TF_Operation* func_op = Use({two, func_feed}); 609 Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); 610 VerifyFDef( 611 {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}), 612 {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}}, 613 {}); 614} 615 616TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) { 617 /* 618 * feed 619 * | 620 * +---------+---+ 621 * | const0 | | 622 * | | | | 623 * | v 
/ | 624 * | split | 625 * | | | | | 626 * | v | v | 627 * | | | 628 * +------+------+ 629 * | 630 * v 631 * 632 * Only the second output from split is used as function output 633 */ 634 // Define 635 TF_Operation* feed = Placeholder(func_graph_, s_); 636 TF_Operation* split = Split3(feed, func_graph_, s_); 637 DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, nullptr); 638 639 // Use, run, and verify 640 TF_Operation* func_feed = Placeholder(host_graph_, s_); 641 TF_Operation* func_op = Use({func_feed}); 642 RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}}, 643 {{3, 4}}); 644 VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}), 645 {{"split3_const0:output:0", "split3_0:0"}, 646 {"feed", "split3_0:1"}, 647 {"split3_0:output:1", "split3"}}, 648 {}); 649} 650 651TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) { 652 /* 653 * feed 654 * | 655 * +---------+---+ 656 * | const0 | | 657 * | | | | 658 * | v / | 659 * | split | 660 * | | | | | 661 * | | v | | 662 * | | | | 663 * +---+-----+---+ 664 * | | 665 * v v 666 * 667 * Second output from split is not used as function output 668 */ 669 // Define 670 TF_Operation* feed = Placeholder(func_graph_, s_); 671 TF_Operation* split = Split3(feed, func_graph_, s_); 672 DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, nullptr); 673 674 // Use, run, and verify 675 TF_Operation* func_feed = Placeholder(host_graph_, s_); 676 TF_Operation* func_op = Use({func_feed}); 677 RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, 678 {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}}); 679 VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}), 680 M({{"split3"}, {"split3_0"}}), 681 {{"split3_const0:output:0", "split3_1:0"}, 682 {"feed", "split3_1:1"}, 683 {"split3_1:output:0", "split3"}, 684 {"split3_1:output:2", "split3_0"}}, 685 {}); 686} 687 688TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) { 689 /* 690 * | 691 * v 692 * split 693 * | | | 694 * | v | 695 * | | 696 * +---+-----+---+ 697 * | | 
| | 698 * | v v | 699 * | add | 700 * | | | 701 * | | | 702 * +------+------+ 703 * | 704 * v 705 */ 706 // Define 707 TF_Operation* feed = Placeholder(func_graph_, s_); 708 TF_Operation* split = Split3(feed, func_graph_, s_); 709 TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); 710 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 711 DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr); 712 713 // Use, run, and verify 714 TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); 715 TF_Operation* func_feed = Placeholder(host_graph_, s_); 716 TF_Operation* func_op = Use({two, func_feed}); 717 Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); 718 VerifyFDef( 719 {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}), 720 {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}}, 721 {}); 722} 723 724TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) { 725 /* 726 * | 727 * v 728 * split 729 * | | | 730 * | v | 731 * | | 732 * input --->| |<--- input 733 * | | 734 * v v 735 * add 736 * | 737 * | 738 * v 739 */ 740 // Define 741 TF_Tensor* tensor_123 = Int32Tensor({1, 2, 3}); 742 TF_Operation* c = Const(tensor_123, func_graph_, s_, "const_array"); 743 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 744 TF_Operation* split = Split3(c, func_graph_, s_); 745 TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); 746 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 747 DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr, true); 748 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 749 EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in " 750 "`inputs` must have a single output. Node split3 has " 751 "3 outputs. 
Encountered while creating function 'MyFunc'"), 752 string(TF_Message(s_))); 753 754 TF_DeleteTensor(tensor_123); 755} 756 757TEST_F(CApiFunctionTest, FunctionWithWhileLoop) { 758 // Inputs to the while loop and the function as a whole 759 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 760 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 761 762 // Outputs of the while loop corresponding to the two inputs above 763 // The first one will the function's output 764 std::vector<TF_Output> outputs; 765 766 // Add while loop to func_graph_ 767 { 768 // The inputs to the while loop 769 std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}}; 770 std::unique_ptr<TF_WhileParams> params(new TF_WhileParams( 771 TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_))); 772 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 773 params->name = "test_loop"; 774 775 // Initialize outputs so we can easily detect errors/bugs 776 outputs.resize(2, {nullptr, -1}); 777 778 // Create loop: while (input1 < input2) input1 += input2 + 1 779 TF_Operation* less_than = LessThan( 780 params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_); 781 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 782 params->cond_output = {less_than, 0}; 783 784 TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1], 785 params->body_graph, s_, "add1"); 786 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 787 TF_Operation* one = ScalarConst(1, params->body_graph, s_); 788 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 789 TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2"); 790 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 791 params->body_outputs[0] = {add2, 0}; 792 params->body_outputs[1] = params->body_inputs[1]; 793 794 // Finalize while loop 795 TF_FinishWhile(params.get(), s_, &outputs[0]); 796 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 797 } 798 799 // Define function, use it in graph, and run 800 
DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, nullptr); 801 TF_Operation* five = ScalarConst(5, host_graph_, s_, "five"); 802 TF_Operation* func_feed = Placeholder(host_graph_, s_); 803 TF_Operation* func_op = Use({func_feed, five}); 804 Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1); 805 806 // Verify input, output, and subset of edges in fdef. 807 // The subset of edges we verify is a chain between feed1 and output to 808 // make sure that the correct output is picked. 809 tensorflow::FunctionDef fdef; 810 ASSERT_TRUE(GetFunctionDef(func_, &fdef)); 811 VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}})); 812 VerifyFDefOutputs(fdef, M({{"test_loop_exit"}})); 813 VerifyFDefEdges(fdef, 814 {{"feed1", "test_loop/Enter:0"}, 815 {"test_loop/Enter:output:0", "test_loop/Merge:0"}, 816 {"test_loop/Merge:output:0", "test_loop/Switch:0"}, 817 {"test_loop/Switch:output_false:0", "test_loop/Exit:0"}, 818 {"test_loop/Exit:output:0", "test_loop_exit"}}, 819 {}, false); 820} 821 822TEST_F(CApiFunctionTest, ControlDependency) { 823 /* 824 * | | scalar 825 * | | . 826 * v v . 
<---- control dependency 827 * add < - 828 * | 829 * v 830 */ 831 // Define 832 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 833 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 834 TF_Operation* five = ScalarConst(5, func_graph_, s_); 835 TF_Operation* add = 836 AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); 837 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 838 Define(-1, {}, {feed1, feed2}, {add}, nullptr); 839 840 // Use, run, and verify 841 TF_Operation* two = ScalarConst(2, host_graph_, s_); 842 TF_Operation* func_feed = Placeholder(host_graph_, s_); 843 TF_Operation* func_op = Use({two, func_feed}); 844 Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); 845 VerifyFDef( 846 {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), 847 {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, 848 {{"scalar", "add_0"}}); 849} 850 851TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { 852 /* 853 * | | scalar 854 * | | . 855 * v v . <---- control dependency 856 * add < - 857 * | 858 * v 859 */ 860 // Define 861 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 862 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 863 TF_Operation* five = ScalarConst(5, func_graph_, s_); 864 TF_Operation* add = 865 AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); 866 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 867 Define(1, {add}, {feed1, feed2}, {add}, nullptr, true); 868 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 869 EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] " 870 "is not in the body. Encountered while creating " 871 "function 'MyFunc'"), 872 string(TF_Message(s_))); 873} 874 875TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { 876 /* 877 * | |. 878 * | | . 879 * | | . 880 * v v . 
<---- control dependency 881 * add < - 882 * | 883 * v 884 */ 885 // Define 886 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 887 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 888 TF_Operation* add = 889 AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); 890 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 891 Define(-1, {}, {feed1, feed2}, {add}, nullptr, true); 892 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 893 EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " 894 "is not in the body. Encountered while creating " 895 "function 'MyFunc'"), 896 string(TF_Message(s_))); 897} 898 899TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { 900 /* 901 * feed 902 * | 903 * +++ 904 * | | 905 * +---+-+---+ 906 * | | | | 907 * | v v | 908 * | add | 909 * | | | 910 * | | | 911 * +----+----+ 912 * | 913 * v 914 */ 915 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 916 TF_Operation* add = Add(feed1, feed1, func_graph_, s_); 917 Define(-1, {}, {feed1, feed1}, {add}, nullptr, true); 918 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 919 EXPECT_EQ( 920 string("TF_Output feed1:0 appears more than once in the input list"), 921 string(TF_Message(s_))); 922} 923 924TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) { 925 /* 926 * | | 927 * v v 928 * add 929 * | 930 * v 931 */ 932 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 933 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 934 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 935 DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, nullptr, true); 936 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 937 EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does " 938 "not have output 2\n\tEncountered while processing " 939 "input 1 into function 'MyFunc'"), 940 string(TF_Message(s_))); 941} 942 943TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) { 944 /* 945 * | | 946 * v v 947 * add 
948 * | 949 * v 950 */ 951 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 952 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 953 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 954 DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, nullptr, true); 955 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 956 EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 " 957 "into function 'MyFunc'"), 958 string(TF_Message(s_))); 959} 960 961TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) { 962 /* 963 * | | 964 * v v 965 * add 966 * | 967 * v 968 */ 969 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 970 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 971 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 972 DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, nullptr, true); 973 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 974 EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does " 975 "not have output 3\n\tEncountered while processing " 976 "output 0 from function 'MyFunc'"), 977 string(TF_Message(s_))); 978} 979 980TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) { 981 /* 982 * | | 983 * v v 984 * add 985 * | 986 * v 987 */ 988 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 989 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 990 Add(feed1, feed2, func_graph_, s_); 991 DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, nullptr, true); 992 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 993 EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 " 994 "from function 'MyFunc'"), 995 string(TF_Message(s_))); 996} 997 998TEST_F(CApiFunctionTest, NodeMissingInput) { 999 /* 1000 * input---> | | <----missing input 1001 * v v 1002 * body----> add 1003 * | 1004 * v 1005 */ 1006 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 1007 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 1008 
TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 1009 DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, nullptr, true); 1010 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 1011 EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' " 1012 "is not available. You might need to include it in inputs " 1013 "or include its source node in the body"), 1014 string(TF_Message(s_))); 1015} 1016 1017TEST_F(CApiFunctionTest, OutputOpNotInBody) { 1018 /* 1019 * | | 1020 * v v 1021 * add scalar (scalar not included in body) 1022 * | | 1023 * v v (function has two outputs) 1024 */ 1025 // Define 1026 TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); 1027 TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); 1028 TF_Operation* scalar = ScalarConst(2, func_graph_, s_); 1029 TF_Operation* add = Add(feed1, feed2, func_graph_, s_); 1030 Define(1, {add}, {feed1, feed2}, {add, scalar}, nullptr, true); 1031 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 1032 EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor " 1033 "among function inputs. Encountered while creating " 1034 "function 'MyFunc'"), 1035 string(TF_Message(s_))); 1036} 1037 1038} // namespace 1039} // namespace tensorflow 1040