c_api_function_test.cc revision 9624d165f1f2c717eda96464fee8bf7229cc14f5
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/c/c_api.h"

#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
25
26namespace tensorflow {
27namespace {
28
29// Specification for expected input/output and its type.
30// DataType value of DT_INVALID signifies that we don't want to
31// check the data type.
32typedef std::pair<string, DataType> IOSpec;
33
34std::vector<IOSpec> M(const std::initializer_list<string>& names) {
35  std::vector<IOSpec> v;
36  for (const string& name : names) {
37    v.push_back(IOSpec(name, DT_INVALID));
38  }
39  return v;
40}
41
42// Specification for an expected edge.
43// src is either:
44// - input name (as it appears in FunctionDef)
45// - name of output tensor (in nested "add:z:0" format)
46// dst is either:
47// - output name (as it appears in FunctionDef)
48// - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
49//      output tensor naming, but it the index is actually an input index)
50struct EdgeSpec : public std::pair<string, string> {
51  typedef std::pair<string, string> Base;
52
53  // Inherit the set of constructors
54  using Base::pair;
55
56  string ToString() const { return strings::StrCat(first, "->", second); }
57};
58
59class CApiFunctionTest : public ::testing::Test {
60 protected:
61  CApiFunctionTest()
62      : s_(TF_NewStatus()),
63        func_graph_(TF_NewGraph()),
64        host_graph_(TF_NewGraph()),
65        func_(nullptr) {}
66
67  void SetUp() override {}
68
69  ~CApiFunctionTest() override {
70    TF_DeleteFunction(func_);
71    TF_DeleteGraph(host_graph_);
72    TF_DeleteGraph(func_graph_);
73    TF_DeleteStatus(s_);
74  }
75
76  void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
77           TF_Operation* output, int32_t expected_result) {
78    Run(inputs, {{output, 0}}, {expected_result});
79  }
80
81  // Run the host graph, which now contains a function and check that
82  // outputs are as expected.
83  // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
84  void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
85            std::initializer_list<TF_Output> outputs,
86            const std::vector<std::vector<int32_t>>& expected_results) {
87    // Create a session for this graph
88    CSession csession(host_graph_, s_);
89    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
90
91    // Run
92    csession.SetInputs(inputs);
93    csession.SetOutputs(outputs);
94    csession.Run(s_);
95    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
96
97    // Check results
98    for (int i = 0; i < expected_results.size(); ++i) {
99      TF_Tensor* out = csession.output_tensor(i);
100      ASSERT_TRUE(out != nullptr);
101      EXPECT_EQ(TF_INT32, TF_TensorType(out));
102      EXPECT_EQ(1, TF_NumDims(out));
103      CompareInt32Tensor(expected_results[i], out);
104    }
105  }
106
107  // Run the host graph, which now contains a function and check that
108  // outputs are as expected.
109  void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
110           std::initializer_list<TF_Output> outputs,
111           const std::vector<int32_t>& expected_results) {
112    // Create a session for this graph.
113    CSession csession(host_graph_, s_);
114    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
115
116    csession.SetInputs(inputs);
117    csession.SetOutputs(outputs);
118    csession.Run(s_);
119    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
120
121    for (int i = 0; i < expected_results.size(); ++i) {
122      TF_Tensor* out = csession.output_tensor(i);
123      ASSERT_TRUE(out != nullptr);
124      EXPECT_EQ(TF_INT32, TF_TensorType(out));
125      EXPECT_EQ(0, TF_NumDims(out));  // scalar
126      ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
127      int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
128      EXPECT_EQ(expected_results[i], *output_contents);
129    }
130  }
131
132  void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
133    int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
134    size_t size = TF_TensorByteSize(t);
135    ASSERT_EQ(expected.size() * sizeof(int32_t), size);
136    for (int i = 0; i < expected.size(); ++i) {
137      ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
138    }
139  }
140
141  std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
142    std::vector<TF_Output> out;
143    for (auto op : ops) {
144      out.push_back({op, 0});
145    }
146    return out;
147  }
148
149  void Define(int num_opers, const std::vector<TF_Operation*>& opers,
150              const std::vector<TF_Operation*>& inputs,
151              const std::vector<TF_Operation*>& outputs,
152              const char** output_names, bool expect_failure = false) {
153    DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
154            expect_failure);
155  }
156
157  // An explicit `num_opers` is needed so that we can distinguish between the
158  // case of no operations specified (-1) and the case of an empty set of
159  // operations specified (0).
160  void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
161               const std::vector<TF_Output>& inputs,
162               const std::vector<TF_Output>& outputs, const char** output_names,
163               bool expect_failure = false) {
164    ASSERT_EQ(func_, nullptr);
165    func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers,
166                               num_opers == -1 ? nullptr : opers.data(),
167                               inputs.size(), inputs.data(), outputs.size(),
168                               outputs.data(), output_names,
169                               /*opts=*/nullptr, s_);
170    if (expect_failure) {
171      ASSERT_EQ(func_, nullptr);
172      return;
173    }
174
175    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
176    ASSERT_NE(func_, nullptr);
177    TF_GraphAddFunction(host_graph_, func_, s_);
178    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
179  }
180
181  TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
182    return UseT(ToOutput(inputs));
183  }
184
185  TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
186    TF_Operation* op;
187    UseHelper(inputs, &op);
188    return op;
189  }
190
191  // All the *Helper methods are used as a workaround for the restrictions that
192  // one cannot call ASSERT_* methods in non-void-returning functions (when
193  // exceptions are disabled during compilation)
194  void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
195    TF_OperationDescription* desc =
196        TF_NewOperation(host_graph_, func_name_, func_node_name_);
197    for (auto input : inputs) {
198      TF_AddInput(desc, input);
199    }
200    // Set device to CPU because some ops inside the function might not be
201    // available on GPU.
202    TF_SetDevice(desc, "/cpu:0");
203    *op = TF_FinishOperation(desc, s_);
204    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
205    ASSERT_NE(*op, nullptr);
206  }
207
208  FunctionDef fdef() {
209    tensorflow::FunctionDef fdef;
210    EXPECT_TRUE(GetFunctionDef(func_, &fdef));
211    return fdef;
212  }
213
214  // logging utility
215  template <class Container>
216  string ToString(const Container& v) {
217    std::stringstream ss;
218    ss << "{";
219    size_t i = 0;
220    for (const auto& e : v) {
221      if (i != 0) {
222        ss << ", ";
223      }
224      ss << e.ToString();
225      ++i;
226    }
227    ss << "}";
228    return ss.str();
229  }
230
231  void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
232                       const std::unordered_set<string>& nodes) {
233    ASSERT_EQ(nodes.size(), fdef.node_def_size())
234        << "Got unexpected number of nodes. Expected: ["
235        << str_util::Join(nodes, ", ")
236        << "] Actual nodes in fdef: " << fdef.DebugString();
237    for (const NodeDef& node_def : fdef.node_def()) {
238      ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
239          << "Got unexpected node: " << node_def.name()
240          << " in fdef: " << fdef.DebugString();
241    }
242  }
243
244  void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
245                        const std::vector<IOSpec>& inputs) {
246    const OpDef& signature = fdef.signature();
247    ASSERT_EQ(inputs.size(), signature.input_arg_size());
248    for (int i = 0; i < inputs.size(); ++i) {
249      const OpDef::ArgDef& arg = signature.input_arg(i);
250      const IOSpec& in = inputs[i];
251      if (in.second != DT_INVALID) {
252        ASSERT_EQ(arg.type(), in.second)
253            << "Got unexpected type for input " << i
254            << ". fdef: " << fdef.DebugString();
255      }
256      ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
257                                      << ". fdef: " << fdef.DebugString();
258    }
259  }
260
261  void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
262                         const std::vector<IOSpec>& outputs) {
263    const OpDef& signature = fdef.signature();
264    ASSERT_EQ(outputs.size(), signature.output_arg_size());
265    for (int i = 0; i < outputs.size(); ++i) {
266      const OpDef::ArgDef& arg = signature.output_arg(i);
267      const IOSpec& out = outputs[i];
268      if (out.second != DT_INVALID) {
269        ASSERT_EQ(arg.type(), out.second)
270            << "Got unexpected type for output " << i
271            << ". fdef: " << fdef.DebugString();
272      }
273      ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
274                                       << ". fdef: " << fdef.DebugString();
275    }
276  }
277
278  void VerifyFDefEdges(
279      const tensorflow::FunctionDef& fdef,
280      const std::vector<EdgeSpec>& e_edges,  // expected edges
281      const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
282      bool is_exact_edges = true) {
283    // Build a set of edges from fdef
284    std::set<EdgeSpec> a_edges;  // actual edges
285    // Get edges from inputs to body nodes and between body nodes
286    for (const NodeDef& node_def : fdef.node_def()) {
287      for (int i = 0; i < node_def.input_size(); ++i) {
288        const string& in = node_def.input(i);
289        const auto& v =
290            a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
291        ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
292                              << strings::StrCat(node_def.name(), ":", i)
293                              << ". fdef: " << fdef.DebugString();
294      }
295    }
296    // Get edges from body nodes to outputs and from inputs to outputs
297    for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
298      const auto& iter = fdef.ret().find(arg.name());
299      if (iter != fdef.ret().end()) {
300        const auto& v = a_edges.insert({iter->second, arg.name()});
301        ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
302                              << arg.name() << ". fdef: " << fdef.DebugString();
303      } else {
304        const auto& v = a_edges.insert({arg.name(), arg.name()});
305        ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
306                              << arg.name() << ". fdef: " << fdef.DebugString();
307      }
308    }
309
310    // Verify edges
311    for (const EdgeSpec& e : e_edges) {
312      ASSERT_TRUE(a_edges.find(e) != a_edges.end())
313          << "Failed to find expected edge " << e.ToString()
314          << " in fdef: " << fdef.DebugString();
315    }
316
317    // If caller specified all edges, check that we have seen all
318    if (is_exact_edges) {
319      ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
320          << "Expected edges: " << ToString(e_edges)
321          << " Expected Control edges: " << ToString(c_edges)
322          << " Actual edges: " << ToString(a_edges)
323          << " in fdef: " << fdef.DebugString();
324    }
325  }
326
327  void VerifyFDef(const std::unordered_set<string>& nodes,
328                  const std::vector<IOSpec>& inputs,
329                  const std::vector<IOSpec>& outputs,
330                  const std::vector<EdgeSpec>& e_edges,  // expected edges
331                  const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
332                  bool is_exact_edges = true) {
333    tensorflow::FunctionDef fdef;
334    ASSERT_TRUE(GetFunctionDef(func_, &fdef));
335    VerifyFDefNodes(fdef, nodes);
336    VerifyFDefInputs(fdef, inputs);
337    VerifyFDefOutputs(fdef, outputs);
338    VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
339  }
340
341  const char* func_name_ = "MyFunc";
342  const char* func_node_name_ = "MyFunc_0";
343  TF_Status* s_;
344  TF_Graph* func_graph_;
345  TF_Graph* host_graph_;
346  TF_Function* func_;
347
348  // Workaround for not being able to initialize empty map using {}
349  std::unordered_set<string> empty_;
350};
351
352TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
353  /*
354   *                constant
355   *                   |
356   *                   v
357   */
358  // Define
359  TF_Operation* c = ScalarConst(10, func_graph_, s_, "scalar10");
360  Define(-1, {}, {}, {c}, nullptr);
361
362  // Use, run, and verify
363  TF_Operation* func_op = Use({});
364  Run({}, func_op, 10);
365  VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}},
366             {{"scalar10_0:output:0", "scalar10"}}, {});
367}
368
369TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
370  /*
371   *                   |
372   *                   v
373   *                 negate
374   *                   |
375   *                   v
376   */
377  // Define
378  TF_Operation* feed = Placeholder(func_graph_, s_);
379  TF_Operation* neg = Neg(feed, func_graph_, s_);
380  Define(-1, {}, {feed}, {neg}, nullptr);
381
382  // Use, run, and verify
383  TF_Operation* func_feed = Placeholder(host_graph_, s_);
384  TF_Operation* func_op = Use({func_feed});
385  Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
386  VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}},
387             {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {});
388}
389
390TEST_F(CApiFunctionTest, ZeroOps_Identity) {
391  /*
392   *                   |
393   *                   |
394   *                   |
395   *                   v
396   */
397  // Define
398  TF_Operation* feed = Placeholder(func_graph_, s_);
399  Define(-1, {}, {feed}, {feed}, nullptr);
400
401  // Use, run, and verify
402  TF_Operation* func_feed = Placeholder(host_graph_, s_);
403  TF_Operation* func_op = Use({func_feed});
404  Run({{func_feed, Int32Tensor(3)}}, func_op, 3);
405  VerifyFDef(empty_, {{"feed", DT_INT32}}, {{"feed_0", DT_INT32}},
406             {{"feed", "feed_0"}}, {});
407}
408
409TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
410  /*
411   *                   |   |
412   *                   \  /
413   *                    \/
414   *                    x
415   *                   /\
416   *                  /  \
417   *                 |   |
418   *                 v   v
419   */
420  // Define
421  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
422  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
423  Define(-1, {}, {feed1, feed2}, {feed2, feed1}, nullptr);
424
425  // Use, run, and verify
426  TF_Operation* two = ScalarConst(2, host_graph_, s_);
427  TF_Operation* func_feed = Placeholder(host_graph_, s_);
428  TF_Operation* func_op = Use({two, func_feed});
429  Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {3, 2});
430  VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"feed2_0"}, {"feed1_0"}}),
431             {{"feed1", "feed1_0"}, {"feed2", "feed2_0"}}, {});
432}
433
434TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
435  /*
436   *                  |  |
437   *                  v  v
438   *                  add
439   *                   |
440   *                   v
441   */
442  // Define
443  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
444  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
445  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
446  Define(-1, {}, {feed1, feed2}, {add}, nullptr);
447
448  // Use, run, and verify
449  TF_Operation* two = ScalarConst(2, host_graph_, s_);
450  TF_Operation* func_feed = Placeholder(host_graph_, s_);
451  TF_Operation* func_op = Use({two, func_feed});
452  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
453  VerifyFDef(
454      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
455      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {});
456}
457
458TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
459  /*
460   *                  |  |
461   *                  v  v
462   *                  add
463   *
464   *            (output ignored)
465   */
466  // Define
467  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
468  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
469  Add(feed1, feed2, func_graph_, s_);
470  Define(-1, {}, {feed1, feed2}, {}, nullptr);
471
472  // Use, run, and verify
473  TF_Operation* two = ScalarConst(2, host_graph_, s_);
474  TF_Operation* func_feed = Placeholder(host_graph_, s_);
475  Use({two, func_feed});
476  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {},
477             {{"feed1", "add:0"}, {"feed2", "add:1"}}, {});
478}
479
480TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
481  /*
482   *                  |  |   |
483   *                  v  v   /
484   *                  add1  /
485   *                   |   |
486   *                   v   v
487   *                   add2
488   *                    |
489   *                    v
490   */
491  // Define
492  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
493  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
494  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
495  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
496  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
497  Define(-1, {}, {feed1, feed2, feed3}, {add2}, nullptr);
498
499  // Use, run, and verify
500  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
501  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
502  TF_Operation* func_feed = Placeholder(host_graph_, s_);
503  TF_Operation* func_op = Use({two, ten, func_feed});
504  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 10 + 3);
505  VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
506             M({{"add2"}}),
507             {{"feed1", "add1:0"},
508              {"feed2", "add1:1"},
509              {"add1:sum:0", "add2_0:0"},
510              {"feed3", "add2_0:1"},
511              {"add2_0:sum:0", "add2"}},
512             {});
513}
514
515TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
516  /*
517   *                  |  |
518   *                  v  v
519   *                  add
520   *                   |
521   *                 +-+-+
522   *                 |   |
523   *                 v   v
524   */
525  // Define
526  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
527  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
528  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
529  Define(-1, {}, {feed1, feed2}, {add, add}, nullptr);
530
531  // Use, run, and verify
532  TF_Operation* two = ScalarConst(2, host_graph_, s_);
533  TF_Operation* func_feed = Placeholder(host_graph_, s_);
534  TF_Operation* func_op = Use({two, func_feed});
535  Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {5, 5});
536  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
537             {{"feed1", "add_1:0"},
538              {"feed2", "add_1:1"},
539              {"add_1:sum:0", "add"},
540              {"add_1:sum:0", "add_0"}},
541             {});
542}
543
544TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
545  /*
546   *                  |  |  |
547   *                  v  v  /
548   *                  add  /
549   *                   |  |
550   *                 +-+  |
551   *                 | |  |
552   *                 | v  v
553   *                 | add
554   *                 |  |
555   *                 v  v
556   */
557  // Define
558  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
559  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
560  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
561  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
562  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
563  Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, nullptr);
564
565  // Use, run, and verify
566  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
567  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
568  TF_Operation* func_feed = Placeholder(host_graph_, s_);
569  TF_Operation* func_op = Use({two, ten, func_feed});
570  Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15});
571  VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
572             M({{"add1"}, {"add2"}}),
573             {{"feed1", "add1_0:0"},
574              {"feed2", "add1_0:1"},
575              {"add1_0:sum:0", "add2_0:0"},
576              {"feed3", "add2_0:1"},
577              {"add1_0:sum:0", "add1"},
578              {"add2_0:sum:0", "add2"}},
579             {});
580}
581
582TEST_F(CApiFunctionTest, FromSubsetOfOps) {
583  /*
584   *                  |  |  |
585   *                  v  v  /
586   *                  add  /
587   *                   |  |
588   *               +---+--+---+
589   *  Ops used     |   |  |   |
590   *  for func     |   v  v   |
591   *     |         |   add    |
592   *     +-------> |    |     |
593   *               |    v     |
594   *               |          |
595   *               +----------+
596   */
597  // Define
598  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
599  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
600  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
601  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
602  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
603  Define(1, {add2}, {add1, feed3}, {add2}, nullptr);
604
605  // Use, run, and verify
606  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
607  TF_Operation* func_feed = Placeholder(host_graph_, s_);
608  TF_Operation* func_op = Use({two, func_feed});
609  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
610  VerifyFDef(
611      {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}),
612      {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}},
613      {});
614}
615
616TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
617  /*
618   *                      feed
619   *                       |
620   *             +---------+---+
621   *             | const0  |   |
622   *             |    |    |   |
623   *             |    v    /   |
624   *             |    split    |
625   *             |   |  |  |   |
626   *             |   v  |  v   |
627   *             |      |      |
628   *             +------+------+
629   *                    |
630   *                    v
631   *
632   *  Only the second output from split is used as function output
633   */
634  // Define
635  TF_Operation* feed = Placeholder(func_graph_, s_);
636  TF_Operation* split = Split3(feed, func_graph_, s_);
637  DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, nullptr);
638
639  // Use, run, and verify
640  TF_Operation* func_feed = Placeholder(host_graph_, s_);
641  TF_Operation* func_op = Use({func_feed});
642  RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}},
643       {{3, 4}});
644  VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
645             {{"split3_const0:output:0", "split3_0:0"},
646              {"feed", "split3_0:1"},
647              {"split3_0:output:1", "split3"}},
648             {});
649}
650
651TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
652  /*
653   *                      feed
654   *                       |
655   *             +---------+---+
656   *             | const0  |   |
657   *             |    |    |   |
658   *             |    v    /   |
659   *             |    split    |
660   *             |   |  |  |   |
661   *             |   |  v  |   |
662   *             |   |     |   |
663   *             +---+-----+---+
664   *                 |     |
665   *                 v     v
666   *
667   *  Second output from split is not used as function output
668   */
669  // Define
670  TF_Operation* feed = Placeholder(func_graph_, s_);
671  TF_Operation* split = Split3(feed, func_graph_, s_);
672  DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, nullptr);
673
674  // Use, run, and verify
675  TF_Operation* func_feed = Placeholder(host_graph_, s_);
676  TF_Operation* func_op = Use({func_feed});
677  RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}},
678       {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}});
679  VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
680             M({{"split3"}, {"split3_0"}}),
681             {{"split3_const0:output:0", "split3_1:0"},
682              {"feed", "split3_1:1"},
683              {"split3_1:output:0", "split3"},
684              {"split3_1:output:2", "split3_0"}},
685             {});
686}
687
688TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
689  /*
690   *                    |
691   *                    v
692   *                  split
693   *                 |  |  |
694   *                 |  v  |
695   *                 |     |
696   *             +---+-----+---+
697   *             |   |     |   |
698   *             |   v     v   |
699   *             |     add     |
700   *             |      |      |
701   *             |      |      |
702   *             +------+------+
703   *                    |
704   *                    v
705   */
706  // Define
707  TF_Operation* feed = Placeholder(func_graph_, s_);
708  TF_Operation* split = Split3(feed, func_graph_, s_);
709  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
710  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
711  DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr);
712
713  // Use, run, and verify
714  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
715  TF_Operation* func_feed = Placeholder(host_graph_, s_);
716  TF_Operation* func_op = Use({two, func_feed});
717  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
718  VerifyFDef(
719      {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
720      {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}},
721      {});
722}
723
724TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
725  /*
726   *                    |
727   *                    v
728   *                  split
729   *                 |  |  |
730   *                 |  v  |
731   *                 |     |
732   *       input --->|     |<--- input
733   *                 |     |
734   *                 v     v
735   *                   add
736   *                    |
737   *                    |
738   *                    v
739   */
740  // Define
741  TF_Tensor* tensor_123 = Int32Tensor({1, 2, 3});
742  TF_Operation* c = Const(tensor_123, func_graph_, s_, "const_array");
743  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
744  TF_Operation* split = Split3(c, func_graph_, s_);
745  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
746  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
747  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr, true);
748  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
749  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
750                   "`inputs` must have a single output. Node split3 has "
751                   "3 outputs. Encountered while creating function 'MyFunc'"),
752            string(TF_Message(s_)));
753
754  TF_DeleteTensor(tensor_123);
755}
756
757TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
758  // Inputs to the while loop and the function as a whole
759  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
760  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
761
762  // Outputs of the while loop corresponding to the two inputs above
763  // The first one will the function's output
764  std::vector<TF_Output> outputs;
765
766  // Add while loop to func_graph_
767  {
768    // The inputs to the while loop
769    std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
770    std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
771        TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_)));
772    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
773    params->name = "test_loop";
774
775    // Initialize outputs so we can easily detect errors/bugs
776    outputs.resize(2, {nullptr, -1});
777
778    // Create loop: while (input1 < input2) input1 += input2 + 1
779    TF_Operation* less_than = LessThan(
780        params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
781    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
782    params->cond_output = {less_than, 0};
783
784    TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
785                             params->body_graph, s_, "add1");
786    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
787    TF_Operation* one = ScalarConst(1, params->body_graph, s_);
788    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
789    TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
790    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
791    params->body_outputs[0] = {add2, 0};
792    params->body_outputs[1] = params->body_inputs[1];
793
794    // Finalize while loop
795    TF_FinishWhile(params.get(), s_, &outputs[0]);
796    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
797  }
798
799  // Define function, use it in graph, and run
800  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, nullptr);
801  TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
802  TF_Operation* func_feed = Placeholder(host_graph_, s_);
803  TF_Operation* func_op = Use({func_feed, five});
804  Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);
805
806  // Verify input, output, and subset of edges in fdef.
807  // The subset of edges we verify is a chain between feed1 and output to
808  // make sure that the correct output is picked.
809  tensorflow::FunctionDef fdef;
810  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
811  VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
812  VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
813  VerifyFDefEdges(fdef,
814                  {{"feed1", "test_loop/Enter:0"},
815                   {"test_loop/Enter:output:0", "test_loop/Merge:0"},
816                   {"test_loop/Merge:output:0", "test_loop/Switch:0"},
817                   {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
818                   {"test_loop/Exit:output:0", "test_loop_exit"}},
819                  {}, false);
820}
821
822TEST_F(CApiFunctionTest, ControlDependency) {
823  /*
824   *                  |  |    scalar
825   *                  |  |    .
826   *                  v  v   . <---- control dependency
827   *                  add < -
828   *                   |
829   *                   v
830   */
831  // Define
832  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
833  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
834  TF_Operation* five = ScalarConst(5, func_graph_, s_);
835  TF_Operation* add =
836      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
837  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
838  Define(-1, {}, {feed1, feed2}, {add}, nullptr);
839
840  // Use, run, and verify
841  TF_Operation* two = ScalarConst(2, host_graph_, s_);
842  TF_Operation* func_feed = Placeholder(host_graph_, s_);
843  TF_Operation* func_op = Use({two, func_feed});
844  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
845  VerifyFDef(
846      {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
847      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
848      {{"scalar", "add_0"}});
849}
850
851TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
852  /*
853   *                  |  |    scalar
854   *                  |  |    .
855   *                  v  v   . <---- control dependency
856   *                  add < -
857   *                   |
858   *                   v
859   */
860  // Define
861  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
862  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
863  TF_Operation* five = ScalarConst(5, func_graph_, s_);
864  TF_Operation* add =
865      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
866  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
867  Define(1, {add}, {feed1, feed2}, {add}, nullptr, true);
868  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
869  EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
870                   "is not in the body. Encountered while creating "
871                   "function 'MyFunc'"),
872            string(TF_Message(s_)));
873}
874
875TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
876  /*
877   *                  |  |.
878   *                  |  |  .
879   *                  |  |   .
880   *                  v  v   . <---- control dependency
881   *                  add < -
882   *                   |
883   *                   v
884   */
885  // Define
886  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
887  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
888  TF_Operation* add =
889      AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
890  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
891  Define(-1, {}, {feed1, feed2}, {add}, nullptr, true);
892  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
893  EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] "
894                   "is not in the body. Encountered while creating "
895                   "function 'MyFunc'"),
896            string(TF_Message(s_)));
897}
898
899TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
900  /*
901   *                  feed
902   *                   |
903   *                  +++
904   *                  | |
905   *              +---+-+---+
906   *              |   | |   |
907   *              |   v v   |
908   *              |   add   |
909   *              |    |    |
910   *              |    |    |
911   *              +----+----+
912   *                   |
913   *                   v
914   */
915  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
916  TF_Operation* add = Add(feed1, feed1, func_graph_, s_);
917  Define(-1, {}, {feed1, feed1}, {add}, nullptr, true);
918  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
919  EXPECT_EQ(
920      string("TF_Output feed1:0 appears more than once in the input list"),
921      string(TF_Message(s_)));
922}
923
924TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
925  /*
926   *                  |  |
927   *                  v  v
928   *                  add
929   *                   |
930   *                   v
931   */
932  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
933  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
934  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
935  DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, nullptr, true);
936  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
937  EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
938                   "not have output 2\n\tEncountered while processing "
939                   "input 1 into function 'MyFunc'"),
940            string(TF_Message(s_)));
941}
942
943TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
944  /*
945   *                  |  |
946   *                  v  v
947   *                  add
948   *                   |
949   *                   v
950   */
951  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
952  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
953  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
954  DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, nullptr, true);
955  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
956  EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 "
957                   "into function 'MyFunc'"),
958            string(TF_Message(s_)));
959}
960
961TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
962  /*
963   *                  |  |
964   *                  v  v
965   *                  add
966   *                   |
967   *                   v
968   */
969  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
970  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
971  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
972  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, nullptr, true);
973  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
974  EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
975                   "not have output 3\n\tEncountered while processing "
976                   "output 0 from function 'MyFunc'"),
977            string(TF_Message(s_)));
978}
979
980TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
981  /*
982   *                  |  |
983   *                  v  v
984   *                  add
985   *                   |
986   *                   v
987   */
988  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
989  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
990  Add(feed1, feed2, func_graph_, s_);
991  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, nullptr, true);
992  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
993  EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 "
994                   "from function 'MyFunc'"),
995            string(TF_Message(s_)));
996}
997
998TEST_F(CApiFunctionTest, NodeMissingInput) {
999  /*
1000   *        input---> |  | <----missing input
1001   *                  v  v
1002   *        body----> add
1003   *                   |
1004   *                   v
1005   */
1006  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
1007  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
1008  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
1009  DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, nullptr, true);
1010  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
1011  EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
1012                   "is not available. You might need to include it in inputs "
1013                   "or include its source node in the body"),
1014            string(TF_Message(s_)));
1015}
1016
1017TEST_F(CApiFunctionTest, OutputOpNotInBody) {
1018  /*
1019   *                  |  |
1020   *                  v  v
1021   *                  add    scalar    (scalar not included in body)
1022   *                   |       |
1023   *                   v       v       (function has two outputs)
1024   */
1025  // Define
1026  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
1027  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
1028  TF_Operation* scalar = ScalarConst(2, func_graph_, s_);
1029  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
1030  Define(1, {add}, {feed1, feed2}, {add, scalar}, nullptr, true);
1031  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
1032  EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor "
1033                   "among function inputs. Encountered while creating "
1034                   "function 'MyFunc'"),
1035            string(TF_Message(s_)));
1036}
1037
1038}  // namespace
1039}  // namespace tensorflow
1040