1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/c_api.h"
17
18#include <algorithm>
19#include <cstddef>
20#include <iterator>
21#include <memory>
22#include <vector>
23
24#include "tensorflow/c/c_test_util.h"
25#include "tensorflow/cc/saved_model/signature_constants.h"
26#include "tensorflow/cc/saved_model/tag_constants.h"
27#include "tensorflow/core/example/example.pb.h"
28#include "tensorflow/core/example/feature.pb.h"
29#include "tensorflow/core/framework/api_def.pb.h"
30#include "tensorflow/core/framework/common_shape_fns.h"
31#include "tensorflow/core/framework/graph.pb_text.h"
32#include "tensorflow/core/framework/node_def.pb_text.h"
33#include "tensorflow/core/framework/node_def_util.h"
34#include "tensorflow/core/framework/op.h"
35#include "tensorflow/core/framework/partial_tensor_shape.h"
36#include "tensorflow/core/framework/tensor.h"
37#include "tensorflow/core/framework/tensor_shape.pb.h"
38#include "tensorflow/core/framework/types.pb.h"
39#include "tensorflow/core/graph/tensor_id.h"
40#include "tensorflow/core/lib/core/error_codes.pb.h"
41#include "tensorflow/core/lib/core/status_test_util.h"
42#include "tensorflow/core/lib/io/path.h"
43#include "tensorflow/core/lib/strings/str_util.h"
44#include "tensorflow/core/lib/strings/strcat.h"
45#include "tensorflow/core/platform/test.h"
46#include "tensorflow/core/protobuf/meta_graph.pb.h"
47#include "tensorflow/core/util/equal_graph_def.h"
48
49namespace tensorflow {
// Forward declarations of the C API's C++ <-> TF_Tensor conversion helpers,
// which are implemented elsewhere (not shown here). TestEncodeDecode below uses
// them to round-trip tensors. NOTE(review): presumably declared locally because
// they are not part of the public C header; confirm before relying on that.
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
52
53namespace {
54
55static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
56  EXPECT_TRUE(StringPiece(s).contains(expected))
57      << "'" << s << "' does not contain '" << expected << "'";
58}
59
60// Returns the GPU device name if there is one (with arbitrary tie breaking if
61// there are more than one), or "" otherwise.
62string GPUDeviceName(TF_Session* session) {
63  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
64      TF_NewStatus(), TF_DeleteStatus);
65  TF_Status* s = status.get();
66  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> list(
67      TF_SessionListDevices(session, s), TF_DeleteDeviceList);
68  TF_DeviceList* device_list = list.get();
69
70  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
71
72  const int num_devices = TF_DeviceListCount(device_list);
73  LOG(INFO) << "There are " << num_devices << " devices.";
74  for (int i = 0; i < num_devices; ++i) {
75    const char* device_name = TF_DeviceListName(device_list, i, s);
76    CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
77    const char* device_type = TF_DeviceListType(device_list, i, s);
78    CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
79    LOG(INFO) << "Device " << i << " has name " << device_name << ", type "
80              << device_type;
81    if (string(device_type) == DEVICE_GPU) {
82      return device_name;
83    }
84  }
85  // No GPU device found.
86  return "";
87}
88
89string GPUDeviceName() {
90  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
91      TF_NewStatus(), TF_DeleteStatus);
92  TF_Status* s = status.get();
93  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph(TF_NewGraph(),
94                                                             TF_DeleteGraph);
95
96  TF_SessionOptions* opts = TF_NewSessionOptions();
97  TF_Session* sess = TF_NewSession(graph.get(), opts, s);
98  TF_DeleteSessionOptions(opts);
99
100  const string gpu_device_name = GPUDeviceName(sess);
101  TF_DeleteSession(sess, s);
102  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
103  return gpu_device_name;
104}
105
106TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); }
107
108TEST(CAPI, Status) {
109  TF_Status* s = TF_NewStatus();
110  EXPECT_EQ(TF_OK, TF_GetCode(s));
111  EXPECT_EQ(string(), TF_Message(s));
112  TF_SetStatus(s, TF_CANCELLED, "cancel");
113  EXPECT_EQ(TF_CANCELLED, TF_GetCode(s));
114  EXPECT_EQ(string("cancel"), TF_Message(s));
115  TF_DeleteStatus(s);
116}
117
118void Deallocator(void* data, size_t, void* arg) {
119  tensorflow::cpu_allocator()->DeallocateRaw(data);
120  *reinterpret_cast<bool*>(arg) = true;
121}
122
123TEST(CAPI, Tensor) {
124  const int num_bytes = 6 * sizeof(float);
125  float* values =
126      reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
127          EIGEN_MAX_ALIGN_BYTES, num_bytes));
128  int64_t dims[] = {2, 3};
129  bool deallocator_called = false;
130  TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
131                              &Deallocator, &deallocator_called);
132  EXPECT_FALSE(deallocator_called);
133  EXPECT_EQ(TF_FLOAT, TF_TensorType(t));
134  EXPECT_EQ(2, TF_NumDims(t));
135  EXPECT_EQ(dims[0], TF_Dim(t, 0));
136  EXPECT_EQ(dims[1], TF_Dim(t, 1));
137  EXPECT_EQ(num_bytes, TF_TensorByteSize(t));
138  EXPECT_EQ(static_cast<void*>(values), TF_TensorData(t));
139  TF_DeleteTensor(t);
140  EXPECT_TRUE(deallocator_called);
141}
142
// Deallocator that deliberately does nothing. It is used where TensorFlow must
// not free anything, e.g. when the tensor was given no data buffer at all.
void NoOpDeallocator(void* /*data*/, size_t /*len*/, void* /*arg*/) {}
144
145TEST(CAPI, MalformedTensor) {
146  // See https://github.com/tensorflow/tensorflow/issues/7394
147  // num_dims = 0 implies a scalar, so should be backed by at least 4 bytes of
148  // data.
149  TF_Tensor* t =
150      TF_NewTensor(TF_FLOAT, nullptr, 0, nullptr, 0, &NoOpDeallocator, nullptr);
151  ASSERT_TRUE(t == nullptr);
152}
153
154TEST(CAPI, AllocateTensor) {
155  const int num_bytes = 6 * sizeof(float);
156  int64_t dims[] = {2, 3};
157  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, num_bytes);
158  EXPECT_EQ(TF_FLOAT, TF_TensorType(t));
159  EXPECT_EQ(2, TF_NumDims(t));
160  EXPECT_EQ(dims[0], TF_Dim(t, 0));
161  EXPECT_EQ(dims[1], TF_Dim(t, 1));
162  EXPECT_EQ(num_bytes, TF_TensorByteSize(t));
163  TF_DeleteTensor(t);
164}
165
166TEST(CAPI, MaybeMove) {
167  const int num_bytes = 6 * sizeof(float);
168  float* values =
169      reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
170          EIGEN_MAX_ALIGN_BYTES, num_bytes));
171  int64_t dims[] = {2, 3};
172  bool deallocator_called = false;
173  TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
174                              &Deallocator, &deallocator_called);
175
176  TF_Tensor* o = TF_TensorMaybeMove(t);
177  ASSERT_TRUE(o == nullptr);  // It is unsafe to move memory TF might not own.
178  TF_DeleteTensor(t);
179  EXPECT_TRUE(deallocator_called);
180}
181
182TEST(CAPI, LibraryLoadFunctions) {
183  // TODO(b/73318067): Fix linking for the GPU test generated by the
184  // tf_cuda_cc_test() bazel rule and remove the next line.
185  if (!GPUDeviceName().empty()) return;
186
187  // Load the library.
188  TF_Status* status = TF_NewStatus();
189  TF_Library* lib =
190      TF_LoadLibrary("tensorflow/c/test_op.so", status);
191  TF_Code code = TF_GetCode(status);
192  string status_msg(TF_Message(status));
193  TF_DeleteStatus(status);
194  ASSERT_EQ(TF_OK, code) << status_msg;
195
196  // Test op list.
197  TF_Buffer op_list_buf = TF_GetOpList(lib);
198  tensorflow::OpList op_list;
199  EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
200  ASSERT_EQ(op_list.op_size(), 1);
201  EXPECT_EQ("TestCApi", op_list.op(0).name());
202
203  TF_DeleteLibraryHandle(lib);
204}
205
206void TestEncodeDecode(int line, const std::vector<string>& data) {
207  const tensorflow::int64 n = data.size();
208  TF_Status* status = TF_NewStatus();
209  for (const std::vector<tensorflow::int64>& dims :
210       std::vector<std::vector<tensorflow::int64>>{
211           {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
212    // Create C++ Tensor
213    Tensor src(tensorflow::DT_STRING, TensorShape(dims));
214    for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
215      src.flat<string>()(i) = data[i];
216    }
217    TF_Tensor* dst = TF_TensorFromTensor(src, status);
218    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
219
220    // Convert back to a C++ Tensor and ensure we get expected output.
221    Tensor output;
222    ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
223    ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
224    for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
225      ASSERT_EQ(data[i], output.flat<string>()(i)) << line;
226    }
227
228    TF_DeleteTensor(dst);
229  }
230  TF_DeleteStatus(status);
231}
232
233TEST(CAPI, TensorEncodeDecodeStrings) {
234  TestEncodeDecode(__LINE__, {});
235  TestEncodeDecode(__LINE__, {"hello"});
236  TestEncodeDecode(__LINE__,
237                   {"the", "quick", "brown", "fox", "jumped", "over"});
238
239  string big(1000, 'a');
240  TestEncodeDecode(__LINE__, {"small", big, "small2"});
241}
242
243TEST(CAPI, SessionOptions) {
244  TF_SessionOptions* opt = TF_NewSessionOptions();
245  TF_DeleteSessionOptions(opt);
246}
247
248TEST(CAPI, DeprecatedSession) {
249  TF_Status* s = TF_NewStatus();
250  TF_SessionOptions* opt = TF_NewSessionOptions();
251  TF_DeprecatedSession* session = TF_NewDeprecatedSession(opt, s);
252  TF_DeleteSessionOptions(opt);
253  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
254
255  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
256  TF_Buffer* run_metadata = TF_NewBuffer();
257  TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0,
258         nullptr, 0, run_metadata, s);
259  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
260  EXPECT_EQ(std::string("Session was not created with a graph before Run()!"),
261            std::string(TF_Message(s)));
262  TF_DeleteBuffer(run_metadata);
263  TF_DeleteBuffer(run_options);
264
265  TF_DeleteDeprecatedSession(session, s);
266  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
267
268  TF_DeleteStatus(s);
269}
270
271TEST(CAPI, DataTypeEnum) {
272  EXPECT_EQ(TF_FLOAT, static_cast<TF_DataType>(tensorflow::DT_FLOAT));
273  EXPECT_EQ(TF_DOUBLE, static_cast<TF_DataType>(tensorflow::DT_DOUBLE));
274  EXPECT_EQ(TF_INT32, static_cast<TF_DataType>(tensorflow::DT_INT32));
275  EXPECT_EQ(TF_UINT8, static_cast<TF_DataType>(tensorflow::DT_UINT8));
276  EXPECT_EQ(TF_INT16, static_cast<TF_DataType>(tensorflow::DT_INT16));
277  EXPECT_EQ(TF_INT8, static_cast<TF_DataType>(tensorflow::DT_INT8));
278  EXPECT_EQ(TF_STRING, static_cast<TF_DataType>(tensorflow::DT_STRING));
279  EXPECT_EQ(TF_COMPLEX64, static_cast<TF_DataType>(tensorflow::DT_COMPLEX64));
280  EXPECT_EQ(TF_COMPLEX, TF_COMPLEX64);
281  EXPECT_EQ(TF_INT64, static_cast<TF_DataType>(tensorflow::DT_INT64));
282  EXPECT_EQ(TF_BOOL, static_cast<TF_DataType>(tensorflow::DT_BOOL));
283  EXPECT_EQ(TF_QINT8, static_cast<TF_DataType>(tensorflow::DT_QINT8));
284  EXPECT_EQ(TF_QUINT8, static_cast<TF_DataType>(tensorflow::DT_QUINT8));
285  EXPECT_EQ(TF_QINT32, static_cast<TF_DataType>(tensorflow::DT_QINT32));
286  EXPECT_EQ(TF_BFLOAT16, static_cast<TF_DataType>(tensorflow::DT_BFLOAT16));
287  EXPECT_EQ(TF_QINT16, static_cast<TF_DataType>(tensorflow::DT_QINT16));
288  EXPECT_EQ(TF_QUINT16, static_cast<TF_DataType>(tensorflow::DT_QUINT16));
289  EXPECT_EQ(TF_UINT16, static_cast<TF_DataType>(tensorflow::DT_UINT16));
290  EXPECT_EQ(TF_COMPLEX128, static_cast<TF_DataType>(tensorflow::DT_COMPLEX128));
291  EXPECT_EQ(TF_HALF, static_cast<TF_DataType>(tensorflow::DT_HALF));
292  EXPECT_EQ(TF_DataTypeSize(TF_DOUBLE),
293            tensorflow::DataTypeSize(tensorflow::DT_DOUBLE));
294  EXPECT_EQ(TF_DataTypeSize(TF_STRING),
295            tensorflow::DataTypeSize(tensorflow::DT_STRING));
296  // Test with invalid type; should always return 0 as documented
297  EXPECT_EQ(TF_DataTypeSize(static_cast<TF_DataType>(0)), 0);
298}
299
300TEST(CAPI, StatusEnum) {
301  EXPECT_EQ(TF_OK, static_cast<TF_Code>(tensorflow::error::OK));
302  EXPECT_EQ(TF_CANCELLED, static_cast<TF_Code>(tensorflow::error::CANCELLED));
303  EXPECT_EQ(TF_UNKNOWN, static_cast<TF_Code>(tensorflow::error::UNKNOWN));
304  EXPECT_EQ(TF_INVALID_ARGUMENT,
305            static_cast<TF_Code>(tensorflow::error::INVALID_ARGUMENT));
306  EXPECT_EQ(TF_DEADLINE_EXCEEDED,
307            static_cast<TF_Code>(tensorflow::error::DEADLINE_EXCEEDED));
308  EXPECT_EQ(TF_NOT_FOUND, static_cast<TF_Code>(tensorflow::error::NOT_FOUND));
309  EXPECT_EQ(TF_ALREADY_EXISTS,
310            static_cast<TF_Code>(tensorflow::error::ALREADY_EXISTS));
311  EXPECT_EQ(TF_PERMISSION_DENIED,
312            static_cast<TF_Code>(tensorflow::error::PERMISSION_DENIED));
313  EXPECT_EQ(TF_UNAUTHENTICATED,
314            static_cast<TF_Code>(tensorflow::error::UNAUTHENTICATED));
315  EXPECT_EQ(TF_RESOURCE_EXHAUSTED,
316            static_cast<TF_Code>(tensorflow::error::RESOURCE_EXHAUSTED));
317  EXPECT_EQ(TF_FAILED_PRECONDITION,
318            static_cast<TF_Code>(tensorflow::error::FAILED_PRECONDITION));
319  EXPECT_EQ(TF_ABORTED, static_cast<TF_Code>(tensorflow::error::ABORTED));
320  EXPECT_EQ(TF_OUT_OF_RANGE,
321            static_cast<TF_Code>(tensorflow::error::OUT_OF_RANGE));
322  EXPECT_EQ(TF_UNIMPLEMENTED,
323            static_cast<TF_Code>(tensorflow::error::UNIMPLEMENTED));
324  EXPECT_EQ(TF_INTERNAL, static_cast<TF_Code>(tensorflow::error::INTERNAL));
325  EXPECT_EQ(TF_UNAVAILABLE,
326            static_cast<TF_Code>(tensorflow::error::UNAVAILABLE));
327  EXPECT_EQ(TF_DATA_LOSS, static_cast<TF_Code>(tensorflow::error::DATA_LOSS));
328}
329
330TEST(CAPI, GetAllOpList) {
331  TF_Buffer* buf = TF_GetAllOpList();
332  tensorflow::OpList op_list;
333  EXPECT_TRUE(op_list.ParseFromArray(buf->data, buf->length));
334  EXPECT_GT(op_list.op_size(), 0);
335  TF_DeleteBuffer(buf);
336}
337
338TEST(CAPI, SetShape) {
339  TF_Status* s = TF_NewStatus();
340  TF_Graph* graph = TF_NewGraph();
341
342  TF_Operation* feed = Placeholder(graph, s);
343  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
344  TF_Output feed_out_0 = TF_Output{feed, 0};
345  int num_dims;
346
347  // Fetch the shape, it should be completely unknown.
348  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
349  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
350  EXPECT_EQ(-1, num_dims);
351
352  // Set the shape to be unknown, expect no change.
353  TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
354  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
355  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
356  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
357  EXPECT_EQ(-1, num_dims);
358
359  // Set the shape to be 2 x Unknown
360  int64_t dims[] = {2, -1};
361  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
362  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
363
364  // Fetch the shape and validate it is 2 by -1.
365  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
366  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
367  EXPECT_EQ(2, num_dims);
368
369  // Resize the dimension vector appropriately.
370  int64_t returned_dims[2];
371  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
372  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
373  EXPECT_EQ(dims[0], returned_dims[0]);
374  EXPECT_EQ(dims[1], returned_dims[1]);
375
376  // Set to a new valid shape: [2, 3]
377  dims[1] = 3;
378  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
379  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
380
381  // Fetch and see that the new value is returned.
382  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
383  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
384  EXPECT_EQ(dims[0], returned_dims[0]);
385  EXPECT_EQ(dims[1], returned_dims[1]);
386
387  // Try to set 'unknown' with unknown rank on the shape and see that
388  // it doesn't change.
389  TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
390  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
391  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
392  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
393  EXPECT_EQ(2, num_dims);
394  EXPECT_EQ(2, returned_dims[0]);
395  EXPECT_EQ(3, returned_dims[1]);
396
397  // Try to set 'unknown' with same rank on the shape and see that
398  // it doesn't change.
399  dims[0] = -1;
400  dims[1] = -1;
401  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
402  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
403  // Fetch and see that the new value is returned.
404  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
405  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
406  EXPECT_EQ(2, num_dims);
407  EXPECT_EQ(2, returned_dims[0]);
408  EXPECT_EQ(3, returned_dims[1]);
409
410  // Try to fetch a shape with the wrong num_dims
411  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
412  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
413
414  // Try to set an invalid shape (cannot change 2x3 to a 2x5).
415  dims[1] = 5;
416  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
417  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
418
419  // Test for a scalar.
420  TF_Operation* three = ScalarConst(3, graph, s);
421  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
422  TF_Output three_out_0 = TF_Output{three, 0};
423
424  num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s);
425  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
426  EXPECT_EQ(0, num_dims);
427  TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
428  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
429
430  // Clean up
431  TF_DeleteGraph(graph);
432  TF_DeleteStatus(s);
433}
434
// End-to-end check of graph construction and introspection through the C API.
// Builds feed (Placeholder) and a scalar const 3, combines them with an AddN
// op named "add", then appends "neg" = Neg(add). Along the way it verifies:
// - the TF_Operation*() query functions;
// - attr lookup, including the error for a missing attr;
// - input/consumer wiring;
// - GraphDef/NodeDef serialization;
// - lookup by name and iteration with TF_GraphNextOperation.
TEST(CAPI, Graph) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Test TF_Operation*() query functions.
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
  EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
  EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
  EXPECT_EQ(1, TF_OperationNumOutputs(feed));
  EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0}));
  EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(0, TF_OperationNumInputs(feed));
  EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
  EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
  EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));

  tensorflow::AttrValue attr_value;
  ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);

  // Test not found errors in TF_Operation*() query functions.
  // An unknown output-list name yields -1 and an INVALID_ARGUMENT status.
  EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));

  // A missing attr makes GetAttrValue fail with a descriptive message.
  ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
  EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
            string(TF_Message(s)));

  // Make a constant oper with the scalar "3".
  TF_Operation* three = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(feed, three, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Test TF_Operation*() query functions.
  EXPECT_EQ(string("add"), string(TF_OperationName(add)));
  EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
  EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
  EXPECT_EQ(1, TF_OperationNumOutputs(add));
  EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0}));
  EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(2, TF_OperationNumInputs(add));
  EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0}));
  EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1}));
  // Input 0 is wired to feed:0 and input 1 to three:0.
  TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0});
  EXPECT_EQ(feed, add_in_0.oper);
  EXPECT_EQ(0, add_in_0.index);
  TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1});
  EXPECT_EQ(three, add_in_1.oper);
  EXPECT_EQ(0, add_in_1.index);
  EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0}));
  EXPECT_EQ(0, TF_OperationNumControlInputs(add));
  EXPECT_EQ(0, TF_OperationNumControlOutputs(add));

  ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
  ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.i(), 2);

  // Placeholder oper now has a consumer.
  ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
  TF_Input feed_port;
  EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1));
  EXPECT_EQ(add, feed_port.oper);
  EXPECT_EQ(0, feed_port.index);

  // The scalar const oper also has a consumer.
  ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0}));
  TF_Input three_port;
  EXPECT_EQ(1,
            TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1));
  EXPECT_EQ(add, three_port.oper);
  EXPECT_EQ(1, three_port.index);

  // Serialize to GraphDef.
  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  // Validate GraphDef is what we expect.
  // Each expected node must appear exactly once, and nothing else may appear.
  bool found_placeholder = false;
  bool found_scalar_const = false;
  bool found_add = false;
  for (const auto& n : graph_def.node()) {
    if (IsPlaceholder(n)) {
      EXPECT_FALSE(found_placeholder);
      found_placeholder = true;
    } else if (IsScalarConst(n, 3)) {
      EXPECT_FALSE(found_scalar_const);
      found_scalar_const = true;
    } else if (IsAddN(n, 2)) {
      EXPECT_FALSE(found_add);
      found_add = true;
    } else {
      ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
    }
  }
  EXPECT_TRUE(found_placeholder);
  EXPECT_TRUE(found_scalar_const);
  EXPECT_TRUE(found_add);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Serialize to NodeDef.
  NodeDef node_def;
  ASSERT_TRUE(GetNodeDef(neg, &node_def));

  // Validate NodeDef is what we expect.
  EXPECT_TRUE(IsNeg(node_def, "add"));

  // Serialize to GraphDef.
  GraphDef graph_def2;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def2));

  // Compare with first GraphDef + added NodeDef.
  // The debug-string comparison is order-sensitive, so this also checks that
  // the new node is serialized after the earlier ones.
  NodeDef* added_node = graph_def.add_node();
  *added_node = node_def;
  EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2));

  // Look up some nodes by name.
  TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
  EXPECT_TRUE(neg == neg2);
  NodeDef node_def2;
  ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
  EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));

  // node_def / node_def2 are reused here for the feed lookup comparison.
  TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
  EXPECT_TRUE(feed == feed2);
  ASSERT_TRUE(GetNodeDef(feed, &node_def));
  ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
  EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));

  // Test iterating through the nodes of a graph.
  // Every operation built above must be visited exactly once, and no others.
  found_placeholder = false;
  found_scalar_const = false;
  found_add = false;
  bool found_neg = false;
  size_t pos = 0;
  TF_Operation* oper;
  while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
    if (oper == feed) {
      EXPECT_FALSE(found_placeholder);
      found_placeholder = true;
    } else if (oper == three) {
      EXPECT_FALSE(found_scalar_const);
      found_scalar_const = true;
    } else if (oper == add) {
      EXPECT_FALSE(found_add);
      found_add = true;
    } else if (oper == neg) {
      EXPECT_FALSE(found_neg);
      found_neg = true;
    } else {
      ASSERT_TRUE(GetNodeDef(oper, &node_def));
      ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def);
    }
  }
  EXPECT_TRUE(found_placeholder);
  EXPECT_TRUE(found_scalar_const);
  EXPECT_TRUE(found_add);
  EXPECT_TRUE(found_neg);

  // Clean up
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
612
613/*
614TODO(skyewm): this test currently DCHECKs, change to bad status
615
616TEST(CAPI, InputFromDifferentGraphError) {
617  TF_Status* s = TF_NewStatus();
618  TF_Graph* g1 = TF_NewGraph();
619  TF_Graph* g2 = TF_NewGraph();
620
621  TF_Operation* feed = Placeholder(g1, s);
622  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
623
624  // Attempt to create node in g2 with input from g1
625  Neg(feed, g2, s);
626  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
627  EXPECT_STREQ("foo", TF_Message(s));
628
629  TF_DeleteGraph(g1);
630  TF_DeleteGraph(g2);
631  TF_DeleteStatus(s);
632}
633*/
634
635TEST(CAPI, ImportGraphDef) {
636  TF_Status* s = TF_NewStatus();
637  TF_Graph* graph = TF_NewGraph();
638
639  // Create a simple graph.
640  Placeholder(graph, s);
641  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
642  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
643  TF_Operation* oper = ScalarConst(3, graph, s);
644  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
645  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
646  Neg(oper, graph, s);
647  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
648  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
649
650  // Export to a GraphDef.
651  TF_Buffer* graph_def = TF_NewBuffer();
652  TF_GraphToGraphDef(graph, graph_def, s);
653  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
654
655  // Import it, with a prefix, in a fresh graph.
656  TF_DeleteGraph(graph);
657  graph = TF_NewGraph();
658  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
659  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
660  TF_GraphImportGraphDef(graph, graph_def, opts, s);
661  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
662
663  TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
664  TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
665  TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg");
666  ASSERT_TRUE(scalar != nullptr);
667  ASSERT_TRUE(feed != nullptr);
668  ASSERT_TRUE(neg != nullptr);
669
670  // Test basic structure of the imported graph.
671  EXPECT_EQ(0, TF_OperationNumInputs(scalar));
672  EXPECT_EQ(0, TF_OperationNumInputs(feed));
673  ASSERT_EQ(1, TF_OperationNumInputs(neg));
674  TF_Output neg_input = TF_OperationInput({neg, 0});
675  EXPECT_EQ(scalar, neg_input.oper);
676  EXPECT_EQ(0, neg_input.index);
677
678  // Test that we can't see control edges involving the source and sink nodes.
679  TF_Operation* control_ops[100];
680  EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
681  EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
682  EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
683  EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
684
685  EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
686  EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
687  EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
688  EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
689
690  EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
691  EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
692  EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
693  EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
694
695  // Import it again, with an input mapping, return outputs, and a return
696  // operation, into the same graph.
697  TF_DeleteImportGraphDefOptions(opts);
698  opts = TF_NewImportGraphDefOptions();
699  TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
700  TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
701  TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
702  TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
703  EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
704  TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
705  EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
706  TF_ImportGraphDefResults* results =
707      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
708  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
709
710  TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
711  TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed");
712  TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg");
713  ASSERT_TRUE(scalar2 != nullptr);
714  ASSERT_TRUE(feed2 != nullptr);
715  ASSERT_TRUE(neg2 != nullptr);
716
717  // Check input mapping
718  neg_input = TF_OperationInput({neg, 0});
719  EXPECT_EQ(scalar, neg_input.oper);
720  EXPECT_EQ(0, neg_input.index);
721
722  // Check return outputs
723  TF_Output* return_outputs;
724  int num_return_outputs;
725  TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
726                                        &return_outputs);
727  ASSERT_EQ(2, num_return_outputs);
728  EXPECT_EQ(feed2, return_outputs[0].oper);
729  EXPECT_EQ(0, return_outputs[0].index);
730  EXPECT_EQ(scalar, return_outputs[1].oper);  // remapped
731  EXPECT_EQ(0, return_outputs[1].index);
732
733  // Check return operation
734  TF_Operation** return_opers;
735  int num_return_opers;
736  TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
737                                           &return_opers);
738  ASSERT_EQ(1, num_return_opers);
739  EXPECT_EQ(scalar2, return_opers[0]);  // not remapped
740
741  TF_DeleteImportGraphDefResults(results);
742
743  // Import again, with control dependencies, into the same graph.
744  TF_DeleteImportGraphDefOptions(opts);
745  opts = TF_NewImportGraphDefOptions();
746  TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
747  TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
748  TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
749  TF_GraphImportGraphDef(graph, graph_def, opts, s);
750  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
751
752  TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar");
753  TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed");
754  TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg");
755  ASSERT_TRUE(scalar3 != nullptr);
756  ASSERT_TRUE(feed3 != nullptr);
757  ASSERT_TRUE(neg3 != nullptr);
758
759  // Check that newly-imported scalar and feed have control deps (neg3 will
760  // inherit them from input)
761  TF_Operation* control_inputs[100];
762  int num_control_inputs = TF_OperationGetControlInputs(
763      scalar3, control_inputs, TF_OperationNumControlInputs(scalar3));
764  ASSERT_EQ(2, num_control_inputs);
765  EXPECT_EQ(feed, control_inputs[0]);
766  EXPECT_EQ(feed2, control_inputs[1]);
767
768  num_control_inputs = TF_OperationGetControlInputs(
769      feed3, control_inputs, TF_OperationNumControlInputs(feed3));
770  ASSERT_EQ(2, num_control_inputs);
771  EXPECT_EQ(feed, control_inputs[0]);
772  EXPECT_EQ(feed2, control_inputs[1]);
773
774  // Export to a graph def so we can import a graph with control dependencies
775  TF_DeleteBuffer(graph_def);
776  graph_def = TF_NewBuffer();
777  TF_GraphToGraphDef(graph, graph_def, s);
778  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
779
780  // Import again, with remapped control dependency, into the same graph
781  TF_DeleteImportGraphDefOptions(opts);
782  opts = TF_NewImportGraphDefOptions();
783  TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
784  TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
785  TF_GraphImportGraphDef(graph, graph_def, opts, s);
786  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
787
788  TF_Operation* scalar4 =
789      TF_GraphOperationByName(graph, "imported4/imported3/scalar");
790  TF_Operation* feed4 =
791      TF_GraphOperationByName(graph, "imported4/imported2/feed");
792
793  // Check that imported `imported3/scalar` has remapped control dep from
794  // original graph and imported control dep
795  num_control_inputs = TF_OperationGetControlInputs(
796      scalar4, control_inputs, TF_OperationNumControlInputs(scalar4));
797  ASSERT_EQ(2, num_control_inputs);
798  EXPECT_EQ(feed, control_inputs[0]);
799  EXPECT_EQ(feed4, control_inputs[1]);
800
801  TF_DeleteImportGraphDefOptions(opts);
802  TF_DeleteBuffer(graph_def);
803
804  // Can add nodes to the imported graph without trouble.
805  Add(feed, scalar, graph, s);
806  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
807
808  TF_DeleteGraph(graph);
809  TF_DeleteStatus(s);
810}
811
812TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
813  TF_Status* s = TF_NewStatus();
814  TF_Graph* graph = TF_NewGraph();
815
816  // Create a graph with two nodes: x and 3
817  Placeholder(graph, s);
818  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
819  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
820  TF_Operation* oper = ScalarConst(3, graph, s);
821  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
822  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
823  Neg(oper, graph, s);
824  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
825  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
826
827  // Export to a GraphDef.
828  TF_Buffer* graph_def = TF_NewBuffer();
829  TF_GraphToGraphDef(graph, graph_def, s);
830  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
831
832  // Import it in a fresh graph with return outputs.
833  TF_DeleteGraph(graph);
834  graph = TF_NewGraph();
835  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
836  TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
837  TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
838  EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
839  TF_Output return_outputs[2];
840  TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
841                                          return_outputs, 2, s);
842  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
843
844  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
845  TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
846  TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
847  ASSERT_TRUE(scalar != nullptr);
848  ASSERT_TRUE(feed != nullptr);
849  ASSERT_TRUE(neg != nullptr);
850
851  // Check return outputs
852  EXPECT_EQ(feed, return_outputs[0].oper);
853  EXPECT_EQ(0, return_outputs[0].index);
854  EXPECT_EQ(scalar, return_outputs[1].oper);
855  EXPECT_EQ(0, return_outputs[1].index);
856
857  TF_DeleteImportGraphDefOptions(opts);
858  TF_DeleteBuffer(graph_def);
859  TF_DeleteGraph(graph);
860  TF_DeleteStatus(s);
861}
862
863TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) {
864  TF_Status* s = TF_NewStatus();
865  TF_Graph* graph = TF_NewGraph();
866
867  // Create a graph with two nodes: x and 3
868  Placeholder(graph, s);
869  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
870  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
871  TF_Operation* oper = ScalarConst(3, graph, s);
872  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
873  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
874  Neg(oper, graph, s);
875  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
876  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
877
878  // Export to a GraphDef.
879  TF_Buffer* graph_def = TF_NewBuffer();
880  TF_GraphToGraphDef(graph, graph_def, s);
881  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
882
883  // Import it in a fresh graph.
884  TF_DeleteGraph(graph);
885  graph = TF_NewGraph();
886  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
887  TF_GraphImportGraphDef(graph, graph_def, opts, s);
888  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
889
890  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
891
892  // Import it in a fresh graph with an unused input mapping.
893  TF_DeleteImportGraphDefOptions(opts);
894  opts = TF_NewImportGraphDefOptions();
895  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
896  TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
897  TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
898  TF_ImportGraphDefResults* results =
899      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
900  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
901
902  // Check unused input mappings
903  int num_unused_input_mappings;
904  const char** src_names;
905  int* src_indexes;
906  TF_ImportGraphDefResultsMissingUnusedInputMappings(
907      results, &num_unused_input_mappings, &src_names, &src_indexes);
908  ASSERT_EQ(1, num_unused_input_mappings);
909  EXPECT_EQ(string("fake"), string(src_names[0]));
910  EXPECT_EQ(0, src_indexes[0]);
911
912  TF_DeleteImportGraphDefResults(results);
913  TF_DeleteImportGraphDefOptions(opts);
914  TF_DeleteBuffer(graph_def);
915  TF_DeleteGraph(graph);
916  TF_DeleteStatus(s);
917}
918
919TEST(CAPI, Session) {
920  TF_Status* s = TF_NewStatus();
921  TF_Graph* graph = TF_NewGraph();
922
923  // Make a placeholder operation.
924  TF_Operation* feed = Placeholder(graph, s);
925  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
926
927  // Make a constant operation with the scalar "2".
928  TF_Operation* two = ScalarConst(2, graph, s);
929  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
930
931  // Add operation.
932  TF_Operation* add = Add(feed, two, graph, s);
933  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
934
935  // Create a session for this graph.
936  CSession csession(graph, s);
937  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
938
939  // Run the graph.
940  csession.SetInputs({{feed, Int32Tensor(3)}});
941  csession.SetOutputs({add});
942  csession.Run(s);
943  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
944  TF_Tensor* out = csession.output_tensor(0);
945  ASSERT_TRUE(out != nullptr);
946  EXPECT_EQ(TF_INT32, TF_TensorType(out));
947  EXPECT_EQ(0, TF_NumDims(out));  // scalar
948  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
949  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
950  EXPECT_EQ(3 + 2, *output_contents);
951
952  // Add another operation to the graph.
953  TF_Operation* neg = Neg(add, graph, s);
954  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
955
956  // Run up to the new operation.
957  csession.SetInputs({{feed, Int32Tensor(7)}});
958  csession.SetOutputs({neg});
959  csession.Run(s);
960  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
961  out = csession.output_tensor(0);
962  ASSERT_TRUE(out != nullptr);
963  EXPECT_EQ(TF_INT32, TF_TensorType(out));
964  EXPECT_EQ(0, TF_NumDims(out));  // scalar
965  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
966  output_contents = static_cast<int32*>(TF_TensorData(out));
967  EXPECT_EQ(-(7 + 2), *output_contents);
968
969  // Clean up
970  csession.CloseAndDelete(s);
971  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
972  TF_DeleteGraph(graph);
973  TF_DeleteStatus(s);
974}
975
976// If `device` is non-empty, run Min op on that device.
977// Otherwise run it on the default device (CPU).
978void RunMinTest(const string& device, bool use_XLA) {
979  TF_Status* s = TF_NewStatus();
980  TF_Graph* graph = TF_NewGraph();
981
982  // Make a placeholder operation.
983  TF_Operation* feed = Placeholder(graph, s);
984  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
985
986  // Make a constant operation with the scalar "0", for axis.
987  TF_Operation* one = ScalarConst(0, graph, s);
988  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
989
990  // Create a session for this graph.
991  CSession csession(graph, s, use_XLA);
992  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
993
994  if (!device.empty()) {
995    LOG(INFO) << "Setting op Min on device " << device;
996  }
997  TF_Operation* min = MinWithDevice(feed, one, graph, device, s);
998  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
999
1000  // Run the graph.
1001  csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}});
1002  csession.SetOutputs({min});
1003  csession.Run(s);
1004  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1005  TF_Tensor* out = csession.output_tensor(0);
1006  ASSERT_TRUE(out != nullptr);
1007  EXPECT_EQ(TF_INT32, TF_TensorType(out));
1008  EXPECT_EQ(0, TF_NumDims(out));  // scalar
1009  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
1010  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
1011  EXPECT_EQ(2, *output_contents);
1012
1013  // Clean up
1014  csession.CloseAndDelete(s);
1015  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1016  TF_DeleteGraph(graph);
1017  TF_DeleteStatus(s);
1018}
1019
1020TEST(CAPI, Session_Min_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/false); }
1021
1022TEST(CAPI, Session_Min_XLA_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/true); }
1023
1024TEST(CAPI, Session_Min_GPU) {
1025  const string gpu_device = GPUDeviceName();
1026  // Skip this test if no GPU is available.
1027  if (gpu_device.empty()) return;
1028
1029  RunMinTest(gpu_device, /*use_XLA=*/false);
1030}
1031
1032TEST(CAPI, Session_Min_XLA_GPU) {
1033  const string gpu_device = GPUDeviceName();
1034  // Skip this test if no GPU is available.
1035  if (gpu_device.empty()) return;
1036
1037  RunMinTest(gpu_device, /*use_XLA=*/true);
1038}
1039
1040TEST(CAPI, SessionPRun) {
1041  TF_Status* s = TF_NewStatus();
1042  TF_Graph* graph = TF_NewGraph();
1043
1044  // Construct the graph: A + 2 + B
1045  TF_Operation* a = Placeholder(graph, s, "A");
1046  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1047
1048  TF_Operation* b = Placeholder(graph, s, "B");
1049  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1050
1051  TF_Operation* two = ScalarConst(2, graph, s);
1052  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1053
1054  TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
1055  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1056
1057  TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
1058  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1059
1060  // Setup a session and a partial run handle.  The partial run will allow
1061  // computation of A + 2 + B in two phases (calls to TF_SessionPRun):
1062  // 1. Feed A and get (A+2)
1063  // 2. Feed B and get (A+2)+B
1064  TF_SessionOptions* opts = TF_NewSessionOptions();
1065  TF_Session* sess = TF_NewSession(graph, opts, s);
1066  TF_DeleteSessionOptions(opts);
1067
1068  TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}};
1069  TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}};
1070
1071  const char* handle = nullptr;
1072  TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
1073                      TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s);
1074  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1075
1076  // Feed A and fetch A + 2.
1077  TF_Output feeds1[] = {TF_Output{a, 0}};
1078  TF_Output fetches1[] = {TF_Output{plus2, 0}};
1079  TF_Tensor* feedValues1[] = {Int32Tensor(1)};
1080  TF_Tensor* fetchValues1[1];
1081  TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
1082                 1, nullptr, 0, s);
1083  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1084  EXPECT_EQ(3, *(static_cast<int32*>(TF_TensorData(fetchValues1[0]))));
1085  TF_DeleteTensor(feedValues1[0]);
1086  TF_DeleteTensor(fetchValues1[0]);
1087
1088  // Feed B and fetch (A + 2) + B.
1089  TF_Output feeds2[] = {TF_Output{b, 0}};
1090  TF_Output fetches2[] = {TF_Output{plusB, 0}};
1091  TF_Tensor* feedValues2[] = {Int32Tensor(4)};
1092  TF_Tensor* fetchValues2[1];
1093  TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2,
1094                 1, nullptr, 0, s);
1095  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1096  EXPECT_EQ(7, *(static_cast<int32*>(TF_TensorData(fetchValues2[0]))));
1097  TF_DeleteTensor(feedValues2[0]);
1098  TF_DeleteTensor(fetchValues2[0]);
1099
1100  // Clean up.
1101  TF_DeletePRunHandle(handle);
1102  TF_DeleteSession(sess, s);
1103  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1104  TF_DeleteGraph(graph);
1105  TF_DeleteStatus(s);
1106}
1107
1108TEST(CAPI, ShapeInferenceError) {
1109  // TF_FinishOperation should fail if the shape of the added operation cannot
1110  // be inferred.
1111  TF_Status* status = TF_NewStatus();
1112  TF_Graph* graph = TF_NewGraph();
1113
1114  // Create this failure by trying to add two nodes with incompatible shapes
1115  // (A tensor with shape [2] and a tensor with shape [3] cannot be added).
1116  const char data[] = {1, 2, 3};
1117  const int64_t vec2_dims[] = {2};
1118  unique_tensor_ptr vec2_tensor(
1119      Int8Tensor(vec2_dims, TF_ARRAYSIZE(vec2_dims), data), TF_DeleteTensor);
1120  TF_Operation* vec2 = Const(vec2_tensor.get(), graph, status, "vec2");
1121  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1122
1123  const int64_t vec3_dims[] = {3};
1124  unique_tensor_ptr vec3_tensor(
1125      Int8Tensor(vec3_dims, TF_ARRAYSIZE(vec3_dims), data), TF_DeleteTensor);
1126  TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
1127  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1128
1129  TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
1130  ASSERT_NE(TF_OK, TF_GetCode(status));
1131  ASSERT_TRUE(add == nullptr);
1132
1133  TF_DeleteGraph(graph);
1134  TF_DeleteStatus(status);
1135}
1136
1137TEST(CAPI, GetOpDef) {
1138  TF_Status* status = TF_NewStatus();
1139  TF_Graph* graph = TF_NewGraph();
1140  TF_Buffer* buffer = TF_NewBuffer();
1141
1142  TF_GraphGetOpDef(graph, "Add", buffer, status);
1143  ASSERT_EQ(TF_OK, TF_GetCode(status));
1144  const OpDef* expected_op_def;
1145  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
1146  string expected_serialized;
1147  expected_op_def->SerializeToString(&expected_serialized);
1148  string actual_string(reinterpret_cast<const char*>(buffer->data),
1149                       buffer->length);
1150  EXPECT_EQ(expected_serialized, actual_string);
1151
1152  TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
1153  EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status));
1154  ExpectHasSubstr(TF_Message(status),
1155                  "Op type not registered 'MyFakeOp' in binary");
1156
1157  TF_DeleteBuffer(buffer);
1158  TF_DeleteGraph(graph);
1159  TF_DeleteStatus(status);
1160}
1161
// Exposes every string of `v` as a (pointer, length) pair in two parallel
// arrays, as expected by C APIs such as TF_SetAttrStringList. On return
// (*ptrs)[i] == v[i].data() and (*lens)[i] == v[i].size(). The bytes are not
// copied, so `v` must outlive any use of *ptrs. Previous contents of *ptrs and
// *lens are released.
void StringVectorToArrays(const std::vector<string>& v,
                          std::unique_ptr<const void* []>* ptrs,
                          std::unique_ptr<size_t[]>* lens) {
  const size_t count = v.size();
  const void** data = new const void*[count];
  ptrs->reset(data);
  size_t* sizes = new size_t[count];
  lens->reset(sizes);
  size_t idx = 0;
  for (const string& item : v) {
    data[idx] = item.data();
    sizes[idx] = item.size();
    ++idx;
  }
}
1172
1173class CApiColocationTest : public ::testing::Test {
1174 protected:
1175  CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
1176
1177  void SetUp() override {
1178    feed1_ = Placeholder(graph_, s_, "feed1");
1179    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1180
1181    feed2_ = Placeholder(graph_, s_, "feed2");
1182    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1183
1184    constant_ = ScalarConst(10, graph_, s_);
1185    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1186
1187    desc_ = TF_NewOperation(graph_, "AddN", "add");
1188    TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}};
1189    TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs));
1190  }
1191
1192  ~CApiColocationTest() override {
1193    TF_DeleteGraph(graph_);
1194    TF_DeleteStatus(s_);
1195  }
1196
1197  void SetViaStringList(TF_OperationDescription* desc,
1198                        const std::vector<string>& list) {
1199    std::unique_ptr<const void* []> list_ptrs;
1200    std::unique_ptr<size_t[]> list_lens;
1201    StringVectorToArrays(list, &list_ptrs, &list_lens);
1202    TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
1203                         list_lens.get(), list.size());
1204  }
1205
1206  void SetViaProto(TF_OperationDescription* desc,
1207                   const std::vector<string>& list) {
1208    tensorflow::AttrValue attr;
1209    for (const string& v : list) {
1210      attr.mutable_list()->add_s(v);
1211    }
1212    string bytes;
1213    attr.SerializeToString(&bytes);
1214    TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(),
1215                         bytes.size(), s_);
1216    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1217  }
1218
1219  void VerifyCollocation(TF_Operation* op,
1220                         const std::vector<string>& expected) {
1221    TF_AttrMetadata m =
1222        TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
1223    if (expected.empty()) {
1224      ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1225      EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
1226                std::string(TF_Message(s_)));
1227      return;
1228    }
1229    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1230    EXPECT_EQ(1, m.is_list);
1231    EXPECT_EQ(expected.size(), m.list_size);
1232    EXPECT_EQ(TF_ATTR_STRING, m.type);
1233    std::vector<void*> values(expected.size());
1234    std::vector<size_t> lens(expected.size());
1235    std::unique_ptr<char[]> storage(new char[m.total_size]);
1236    TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName,
1237                                  values.data(), lens.data(), expected.size(),
1238                                  storage.get(), m.total_size, s_);
1239    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1240    for (int i = 0; i < expected.size(); ++i) {
1241      EXPECT_EQ(expected[i],
1242                string(static_cast<const char*>(values[i]), lens[i]));
1243    }
1244  }
1245
1246  void FinishAndVerify(TF_OperationDescription* desc,
1247                       const std::vector<string>& expected) {
1248    TF_Operation* op = TF_FinishOperation(desc_, s_);
1249    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1250    VerifyCollocation(op, expected);
1251  }
1252
1253  TF_Status* s_;
1254  TF_Graph* graph_;
1255  TF_Operation* feed1_;
1256  TF_Operation* feed2_;
1257  TF_Operation* constant_;
1258  TF_OperationDescription* desc_;
1259};
1260
1261TEST_F(CApiColocationTest, ColocateWith) {
1262  TF_ColocateWith(desc_, feed1_);
1263  FinishAndVerify(desc_, {"loc:@feed1"});
1264}
1265
1266TEST_F(CApiColocationTest, StringList) {
1267  SetViaStringList(desc_, {"loc:@feed1"});
1268  FinishAndVerify(desc_, {"loc:@feed1"});
1269}
1270
1271TEST_F(CApiColocationTest, Proto) {
1272  SetViaProto(desc_, {"loc:@feed1"});
1273  FinishAndVerify(desc_, {"loc:@feed1"});
1274}
1275
1276TEST_F(CApiColocationTest, ColocateWith_StringList) {
1277  TF_ColocateWith(desc_, feed1_);
1278  SetViaStringList(desc_, {"loc:@feed2"});
1279  FinishAndVerify(desc_, {"loc:@feed2"});
1280}
1281
1282TEST_F(CApiColocationTest, ColocateWith_Proto) {
1283  TF_ColocateWith(desc_, feed1_);
1284  SetViaProto(desc_, {"loc:@feed2"});
1285  FinishAndVerify(desc_, {"loc:@feed2"});
1286}
1287
1288TEST_F(CApiColocationTest, StringList_ColocateWith) {
1289  SetViaStringList(desc_, {"loc:@feed2"});
1290  TF_ColocateWith(desc_, feed1_);
1291  FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1292}
1293
1294TEST_F(CApiColocationTest, Proto_ColocateWith) {
1295  SetViaProto(desc_, {"loc:@feed2"});
1296  TF_ColocateWith(desc_, feed1_);
1297  FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1298}
1299
1300TEST_F(CApiColocationTest, ColocateWith_ColocateWith) {
1301  TF_ColocateWith(desc_, feed1_);
1302  TF_ColocateWith(desc_, feed2_);
1303  FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1304}
1305
1306TEST_F(CApiColocationTest, Proto_StringList) {
1307  SetViaProto(desc_, {"loc:@feed1"});
1308  SetViaStringList(desc_, {"loc:@feed2"});
1309  FinishAndVerify(desc_, {"loc:@feed2"});
1310}
1311
1312TEST_F(CApiColocationTest, StringList_Proto) {
1313  SetViaStringList(desc_, {"loc:@feed1"});
1314  SetViaProto(desc_, {"loc:@feed2"});
1315  FinishAndVerify(desc_, {"loc:@feed2"});
1316}
1317
1318TEST_F(CApiColocationTest, ClearViaStringList) {
1319  TF_ColocateWith(desc_, feed1_);
1320  SetViaStringList(desc_, {});
1321  FinishAndVerify(desc_, {});
1322}
1323
1324TEST_F(CApiColocationTest, ClearViaProto) {
1325  TF_ColocateWith(desc_, feed1_);
1326  SetViaProto(desc_, {});
1327  FinishAndVerify(desc_, {});
1328}
1329
1330TEST(CAPI, SavedModel) {
1331  // Load the saved model.
1332  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
1333  const string saved_model_dir = tensorflow::io::JoinPath(
1334      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
1335  TF_SessionOptions* opt = TF_NewSessionOptions();
1336  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
1337  TF_Buffer* metagraph = TF_NewBuffer();
1338  TF_Status* s = TF_NewStatus();
1339  const char* tags[] = {tensorflow::kSavedModelTagServe};
1340  TF_Graph* graph = TF_NewGraph();
1341  TF_Session* session = TF_LoadSessionFromSavedModel(
1342      opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
1343  TF_DeleteBuffer(run_options);
1344  TF_DeleteSessionOptions(opt);
1345  tensorflow::MetaGraphDef metagraph_def;
1346  metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
1347  TF_DeleteBuffer(metagraph);
1348
1349  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1350  CSession csession(session);
1351
1352  // Retrieve the regression signature from meta graph def.
1353  const auto signature_def_map = metagraph_def.signature_def();
1354  const auto signature_def = signature_def_map.at("regress_x_to_y");
1355
1356  const string input_name =
1357      signature_def.inputs().at(tensorflow::kRegressInputs).name();
1358  const string output_name =
1359      signature_def.outputs().at(tensorflow::kRegressOutputs).name();
1360
1361  // Write {0, 1, 2, 3} as tensorflow::Example inputs.
1362  Tensor input(tensorflow::DT_STRING, TensorShape({4}));
1363  for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
1364    tensorflow::Example example;
1365    auto* feature_map = example.mutable_features()->mutable_feature();
1366    (*feature_map)["x"].mutable_float_list()->add_value(i);
1367    input.flat<string>()(i) = example.SerializeAsString();
1368  }
1369
1370  const tensorflow::string input_op_name =
1371      tensorflow::ParseTensorName(input_name).first.ToString();
1372  TF_Operation* input_op =
1373      TF_GraphOperationByName(graph, input_op_name.c_str());
1374  ASSERT_TRUE(input_op != nullptr);
1375  csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
1376  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1377
1378  const tensorflow::string output_op_name =
1379      tensorflow::ParseTensorName(output_name).first.ToString();
1380  TF_Operation* output_op =
1381      TF_GraphOperationByName(graph, output_op_name.c_str());
1382  ASSERT_TRUE(output_op != nullptr);
1383  csession.SetOutputs({output_op});
1384  csession.Run(s);
1385  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1386
1387  TF_Tensor* out = csession.output_tensor(0);
1388  ASSERT_TRUE(out != nullptr);
1389  EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
1390  EXPECT_EQ(2, TF_NumDims(out));
1391  EXPECT_EQ(4, TF_Dim(out, 0));
1392  EXPECT_EQ(1, TF_Dim(out, 1));
1393  float* values = static_cast<float*>(TF_TensorData(out));
1394  // These values are defined to be (input / 2) + 2.
1395  EXPECT_EQ(2, values[0]);
1396  EXPECT_EQ(2.5, values[1]);
1397  EXPECT_EQ(3, values[2]);
1398  EXPECT_EQ(3.5, values[3]);
1399
1400  csession.CloseAndDelete(s);
1401  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1402  TF_DeleteGraph(graph);
1403  TF_DeleteStatus(s);
1404}
1405
1406TEST(CAPI, SavedModelNullArgsAreValid) {
1407  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
1408  const string saved_model_dir = tensorflow::io::JoinPath(
1409      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
1410  TF_SessionOptions* opt = TF_NewSessionOptions();
1411  TF_Status* s = TF_NewStatus();
1412  const char* tags[] = {tensorflow::kSavedModelTagServe};
1413  TF_Graph* graph = TF_NewGraph();
1414  // NULL run_options and meta_graph_def should work.
1415  TF_Session* session = TF_LoadSessionFromSavedModel(
1416      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
1417  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1418  TF_DeleteSessionOptions(opt);
1419  TF_CloseSession(session, s);
1420  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1421  TF_DeleteSession(session, s);
1422  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1423  TF_DeleteGraph(graph);
1424  TF_DeleteStatus(s);
1425}
1426
// Test-only op: a float/double identity-shaped op (unknown output shape) with
// no gradient registered. The gradient tests in this file expect
// TF_AddGradients to fail with "No gradient defined for op:
// TestOpWithNoGradient ..." when differentiating through it.
REGISTER_OP("TestOpWithNoGradient")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, double}")
    .Doc(R"doc(
Test op with no grad registered.

x: input
y: output
)doc")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1438
1439class CApiGradientsTest : public ::testing::Test {
1440 protected:
1441  CApiGradientsTest()
1442      : s_(TF_NewStatus()),
1443        graph_(TF_NewGraph()),
1444        expected_graph_(TF_NewGraph()) {}
1445
1446  ~CApiGradientsTest() override {
1447    TF_DeleteGraph(graph_);
1448    TF_DeleteGraph(expected_graph_);
1449    TF_DeleteStatus(s_);
1450  }
1451
1452  void TestGradientsSuccess(bool grad_inputs_provided) {
1453    TF_Output inputs[2];
1454    TF_Output outputs[1];
1455    TF_Output grad_outputs[2];
1456    TF_Output expected_grad_outputs[2];
1457
1458    BuildSuccessGraph(inputs, outputs);
1459    BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
1460
1461    AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
1462
1463    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1464
1465    // Compare that the graphs match.
1466    GraphDef expected_gdef;
1467    GraphDef gdef;
1468    EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
1469    EXPECT_TRUE(GetGraphDef(graph_, &gdef));
1470    TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
1471
1472    // Compare that the output of the gradients of both graphs match.
1473    RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
1474  }
1475
1476  void TestGradientsError(bool grad_inputs_provided) {
1477    TF_Output inputs[1];
1478    TF_Output outputs[1];
1479    TF_Output grad_outputs[1];
1480
1481    BuildErrorGraph(inputs, outputs);
1482
1483    AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
1484
1485    string expected_msg =
1486        "No gradient defined for op: TestOpWithNoGradient. Please see "
1487        "https://www.tensorflow.org/code/"
1488        "tensorflow/cc/gradients/README.md"
1489        " for instructions on how to add C++ gradients.";
1490    EXPECT_EQ(expected_msg, TF_Message(s_));
1491  }
1492
1493  // Run the graph and ensure that the gradient values are as expected.
1494  void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
1495                                  TF_Output* expected_grad_outputs) {
1496    std::unique_ptr<CSession> csession(new CSession(graph_, s_));
1497    std::unique_ptr<CSession> expected_csession(
1498        new CSession(expected_graph_, s_));
1499
1500    std::vector<TF_Output> grad_outputs_vec;
1501    grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
1502    csession->SetOutputs(grad_outputs_vec);
1503    csession->Run(s_);
1504    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1505    TF_Tensor* out0 = csession->output_tensor(0);
1506    TF_Tensor* out1 = csession->output_tensor(1);
1507
1508    std::vector<TF_Output> expected_grad_outputs_vec;
1509    expected_grad_outputs_vec.assign(expected_grad_outputs,
1510                                     expected_grad_outputs + 2);
1511    expected_csession->SetOutputs(expected_grad_outputs_vec);
1512    expected_csession->Run(s_);
1513    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1514    TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
1515    TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
1516
1517    CompareTensors(out0, expected_out0);
1518    CompareTensors(out1, expected_out1);
1519  }
1520
1521  void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
1522    float* a_data = static_cast<float*>(TF_TensorData(a));
1523    float* b_data = static_cast<float*>(TF_TensorData(b));
1524    EXPECT_EQ(*a_data, *b_data);
1525  }
1526
1527  void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
1528                    TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
1529    if (grad_inputs_provided) {
1530      TF_Output grad_inputs[1];
1531      const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
1532      TF_Operation* grad_inputs_op =
1533          FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
1534      grad_inputs[0] = TF_Output{grad_inputs_op, 0};
1535      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
1536                      s_, grad_outputs);
1537    } else {
1538      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
1539                      grad_outputs);
1540    }
1541  }
1542
1543  void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
1544    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1545    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1546    TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
1547    inputs[0] = TF_Output{const0, 0};
1548    outputs[0] = TF_Output{nograd, 0};
1549    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1550  }
1551
1552  void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
1553    // Construct the following graph:
1554    //            |
1555    //           z|
1556    //            |
1557    //          MatMul
1558    //         /       \
1559    //        ^         ^
1560    //        |         |
1561    //       x|        y|
1562    //        |         |
1563    //        |         |
1564    //      Const_0    Const_1
1565    //
1566    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1567    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1568    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1569    TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
1570    TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
1571    inputs[0] = TF_Output{const0, 0};
1572    inputs[1] = TF_Output{const1, 0};
1573    outputs[0] = TF_Output{matmul, 0};
1574    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1575  }
1576
1577  void BuildExpectedGraph(bool grad_inputs_provided,
1578                          TF_Output* expected_grad_outputs) {
1579    // The expected graph looks like this if grad_inputs_provided.
1580    // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
1581    //      ^             ^
1582    //    dy|           dx|        // MatMul Gradient Graph
1583    //      |             |
1584    //   MatMul_2      MatMul_1
1585    //   ^   ^          ^    ^
1586    //   |   |----------|    |
1587    //   |        ^          |
1588    //   |      dz|          |
1589    //   |        |          |
1590    //   |     Const_3       |
1591    //   |                   |
1592    //   |        ^          |
1593    //   |       z|          |     // MatMul Forward Graph
1594    //   |        |          |
1595    //   |      MatMul       |
1596    //   |     /       \     |
1597    //   |    ^         ^    |
1598    //   |    |         |    |
1599    //   |---x|        y|----|
1600    //        |         |
1601    //        |         |
1602    //      Const_0   Const_1
1603    //
1604    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1605    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1606    TF_Operation* const0 =
1607        FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
1608    TF_Operation* const1 =
1609        FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
1610    TF_Operation* matmul =
1611        MatMul(expected_graph_, s_, const0, const1, "MatMul");
1612
1613    TF_Operation* const3;
1614    if (grad_inputs_provided) {
1615      const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
1616      const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
1617    } else {
1618      const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
1619    }
1620
1621    TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
1622                                   "gradients/MatMul", false, true);
1623    TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
1624                                   "gradients/MatMul_1", true, false);
1625    expected_grad_outputs[0] = {matmul1, 0};
1626    expected_grad_outputs[1] = {matmul2, 0};
1627  }
1628
1629  TF_Tensor* FloatTensor2x2(const float* values) {
1630    const int64_t dims[2] = {2, 2};
1631    TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
1632    memcpy(TF_TensorData(t), values, sizeof(float) * 4);
1633    return t;
1634  }
1635
1636  TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
1637                              const float* values, const char* name) {
1638    unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
1639    TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
1640    TF_SetAttrTensor(desc, "value", tensor.get(), s);
1641    if (TF_GetCode(s) != TF_OK) return nullptr;
1642    TF_SetAttrType(desc, "dtype", TF_FLOAT);
1643    TF_Operation* op = TF_FinishOperation(desc, s);
1644    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1645    return op;
1646  }
1647
1648  TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
1649                       TF_Operation* r, const char* name,
1650                       bool transpose_a = false, bool transpose_b = false) {
1651    TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
1652    if (transpose_a) {
1653      TF_SetAttrBool(desc, "transpose_a", 1);
1654    }
1655    if (transpose_b) {
1656      TF_SetAttrBool(desc, "transpose_b", 1);
1657    }
1658    TF_AddInput(desc, {l, 0});
1659    TF_AddInput(desc, {r, 0});
1660    TF_Operation* op = TF_FinishOperation(desc, s);
1661    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1662    return op;
1663  }
1664
1665  TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1666                         const char* name) {
1667    TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
1668    TF_AddInput(desc, {in, 0});
1669    TF_Operation* op = TF_FinishOperation(desc, s);
1670    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1671    return op;
1672  }
1673
1674  TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1675                             const char* name) {
1676    TF_OperationDescription* desc =
1677        TF_NewOperation(graph, "TestOpWithNoGradient", name);
1678    TF_AddInput(desc, {in, 0});
1679    TF_Operation* op = TF_FinishOperation(desc, s);
1680    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1681    return op;
1682  }
1683
  // Status object the fixture hands to every TF_* helper call.
  TF_Status* s_;
  // Graph the forward graph is built in and that TF_AddGradients extends.
  TF_Graph* graph_;
  // Hand-built graph describing the expected result of adding gradients;
  // presumably compared against graph_ by the success tests.
  TF_Graph* expected_graph_;
1687};
1688
1689TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); }
1690
1691TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
1692  TestGradientsSuccess(false);
1693}
1694
1695TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
1696  TestGradientsError(true);
1697}
1698
1699TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
1700  TestGradientsError(false);
1701}
1702
// REGISTER_OP calls for the CApiAttributesTest test cases.
// For an attr type T, ATTR_TEST_REGISTER_OP(T) registers two ops, each with a
// single attribute called 'v':
//   * "CApiAttributesTestOp<T>"     where 'v' has type T, and
//   * "CApiAttributesTestOpList<T>" where 'v' has type list(T).
// Both use UnknownShape as their shape function. CApiAttributesTest::init()
// derives these op names from the attr type string it is given.
#define ATTR_TEST_REGISTER_OP(type)                           \
  REGISTER_OP("CApiAttributesTestOp" #type)                   \
      .Attr("v: " #type)                                      \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
  REGISTER_OP("CApiAttributesTestOpList" #type)               \
      .Attr("v: list(" #type ")")                             \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(string);
ATTR_TEST_REGISTER_OP(int);
ATTR_TEST_REGISTER_OP(float);
ATTR_TEST_REGISTER_OP(bool);
ATTR_TEST_REGISTER_OP(type);
ATTR_TEST_REGISTER_OP(shape);
ATTR_TEST_REGISTER_OP(tensor);
// Only needed for the registrations above.
#undef ATTR_TEST_REGISTER_OP
1722
1723class CApiAttributesTest : public ::testing::Test {
1724 protected:
1725  CApiAttributesTest()
1726      : s_(TF_NewStatus()), graph_(TF_NewGraph()), counter_(0) {}
1727
1728  ~CApiAttributesTest() override {
1729    TF_DeleteGraph(graph_);
1730    TF_DeleteStatus(s_);
1731  }
1732
1733  TF_OperationDescription* init(string type) {
1734    // Construct op_name to match the name used by REGISTER_OP in the
1735    // ATTR_TEST_REGISTER calls above.
1736    string op_name = "CApiAttributesTestOp";
1737    if (type.find("list(") == 0) {
1738      op_name += "List";
1739      type = type.replace(0, 5, "");
1740      type = type.replace(type.size() - 1, 1, "");
1741    }
1742    op_name += type;
1743    return TF_NewOperation(
1744        graph_, op_name.c_str(),
1745        ::tensorflow::strings::StrCat("name", counter_++).c_str());
1746  }
1747
1748  TF_Status* s_;
1749
1750 private:
1751  TF_Graph* graph_;
1752  int counter_;
1753};
1754
// Helper macros for the TF_OperationGetAttr* tests.
// TODO(ashankar): Use gmock matchers instead?
// (https://github.com/google/googletest/blob/master/googlemock/docs/CookBook.md#writing-new-parameterized-matchers-quickly)
// That will require setting up the tensorflow build with gmock.
//
// EXPECT_TF_META checks TF_OperationGetAttrMetadata() for `attr_name` on the
// TF_Operation* variable named `oper` in the calling scope, using the
// fixture's status `s_`. Pass expected_list_size = -1 for a non-list
// attribute. Any value >= 0 means the attribute must be a list of exactly
// that length. The size is compared as a signed 64-bit value so that unsigned
// arguments such as `list.size()` do not form an always-true comparison.
#define EXPECT_TF_META(attr_name, expected_list_size, expected_type, \
                       expected_total_size)                          \
  do {                                                               \
    auto m = TF_OperationGetAttrMetadata(oper, (attr_name), s_);     \
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);              \
    const unsigned char e =                                          \
        static_cast<int64_t>(expected_list_size) >= 0 ? 1 : 0;       \
    EXPECT_EQ(e, m.is_list);                                         \
    EXPECT_EQ((expected_list_size), m.list_size);                    \
    EXPECT_EQ((expected_type), m.type);                              \
    EXPECT_EQ((expected_total_size), m.total_size);                  \
  } while (0)
1770
1771TEST_F(CApiAttributesTest, String) {
1772  auto desc = init("string");
1773  TF_SetAttrString(desc, "v", "bunny", 5);
1774
1775  auto oper = TF_FinishOperation(desc, s_);
1776  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1777  EXPECT_TF_META("v", -1, TF_ATTR_STRING, 5);
1778  std::unique_ptr<char[]> value(new char[5]);
1779
1780  TF_OperationGetAttrString(oper, "v", value.get(), 5, s_);
1781  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1782  EXPECT_EQ("bunny", string(static_cast<const char*>(value.get()), 5));
1783}
1784
1785TEST_F(CApiAttributesTest, StringList) {
1786  std::vector<string> list = {"bugs", "bunny", "duck"};
1787  std::unique_ptr<const void* []> list_ptrs;
1788  std::unique_ptr<size_t[]> list_lens;
1789  StringVectorToArrays(list, &list_ptrs, &list_lens);
1790  int list_total_size = 0;
1791  for (const auto& s : list) {
1792    list_total_size += s.size();
1793  }
1794
1795  auto desc = init("list(string)");
1796  TF_SetAttrStringList(desc, "v", list_ptrs.get(), list_lens.get(),
1797                       list.size());
1798
1799  auto oper = TF_FinishOperation(desc, s_);
1800  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1801
1802  EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
1803  std::unique_ptr<void* []> values(new void*[list.size()]);
1804  std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
1805  std::unique_ptr<char[]> storage(new char[list_total_size]);
1806  TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(),
1807                                list.size(), storage.get(), list_total_size,
1808                                s_);
1809  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1810  for (size_t i = 0; i < list.size(); ++i) {
1811    EXPECT_EQ(list[i].size(), lens[i]) << i;
1812    EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
1813        << i;
1814  }
1815}
1816
1817TEST_F(CApiAttributesTest, Int) {
1818  auto desc = init("int");
1819  TF_SetAttrInt(desc, "v", 31415);
1820
1821  auto oper = TF_FinishOperation(desc, s_);
1822  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1823  EXPECT_TF_META("v", -1, TF_ATTR_INT, -1);
1824
1825  int64_t value;
1826  TF_OperationGetAttrInt(oper, "v", &value, s_);
1827  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1828  EXPECT_EQ(31415, value);
1829}
1830
1831TEST_F(CApiAttributesTest, IntList) {
1832  const int64_t list[] = {1, 2, 3, 4};
1833  const size_t list_size = TF_ARRAYSIZE(list);
1834
1835  auto desc = init("list(int)");
1836  TF_SetAttrIntList(desc, "v", list, list_size);
1837
1838  auto oper = TF_FinishOperation(desc, s_);
1839  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1840
1841  int64_t values[list_size];
1842  EXPECT_TF_META("v", list_size, TF_ATTR_INT, -1);
1843  TF_OperationGetAttrIntList(oper, "v", values, list_size, s_);
1844  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1845  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
1846}
1847
1848TEST_F(CApiAttributesTest, Float) {
1849  auto desc = init("float");
1850  TF_SetAttrFloat(desc, "v", 2.718);
1851
1852  auto oper = TF_FinishOperation(desc, s_);
1853  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1854  EXPECT_TF_META("v", -1, TF_ATTR_FLOAT, -1);
1855
1856  float value;
1857  TF_OperationGetAttrFloat(oper, "v", &value, s_);
1858  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1859  EXPECT_FLOAT_EQ(2.718, value);
1860}
1861
1862TEST_F(CApiAttributesTest, FloatList) {
1863  const float list[] = {1.414, 2.718, 3.1415};
1864  const size_t list_size = TF_ARRAYSIZE(list);
1865
1866  auto desc = init("list(float)");
1867  TF_SetAttrFloatList(desc, "v", list, list_size);
1868
1869  auto oper = TF_FinishOperation(desc, s_);
1870  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1871
1872  float values[list_size];
1873  EXPECT_TF_META("v", list_size, TF_ATTR_FLOAT, -1);
1874  TF_OperationGetAttrFloatList(oper, "v", values, list_size, s_);
1875  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1876  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
1877}
1878
1879TEST_F(CApiAttributesTest, Bool) {
1880  auto desc = init("bool");
1881  TF_SetAttrBool(desc, "v", 1);
1882
1883  auto oper = TF_FinishOperation(desc, s_);
1884  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1885  EXPECT_TF_META("v", -1, TF_ATTR_BOOL, -1);
1886
1887  unsigned char value;
1888  TF_OperationGetAttrBool(oper, "v", &value, s_);
1889  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1890  EXPECT_EQ(1, value);
1891}
1892
1893TEST_F(CApiAttributesTest, BoolList) {
1894  const unsigned char list[] = {0, 1, 1, 0, 0, 1, 1};
1895  const size_t list_size = TF_ARRAYSIZE(list);
1896
1897  auto desc = init("list(bool)");
1898  TF_SetAttrBoolList(desc, "v", list, list_size);
1899
1900  auto oper = TF_FinishOperation(desc, s_);
1901  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1902
1903  unsigned char values[list_size];
1904  EXPECT_TF_META("v", list_size, TF_ATTR_BOOL, -1);
1905  TF_OperationGetAttrBoolList(oper, "v", values, list_size, s_);
1906  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1907  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
1908}
1909
1910TEST_F(CApiAttributesTest, Type) {
1911  auto desc = init("type");
1912  TF_SetAttrType(desc, "v", TF_COMPLEX128);
1913
1914  auto oper = TF_FinishOperation(desc, s_);
1915  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1916  EXPECT_TF_META("v", -1, TF_ATTR_TYPE, -1);
1917
1918  TF_DataType value;
1919  TF_OperationGetAttrType(oper, "v", &value, s_);
1920  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1921  EXPECT_EQ(TF_COMPLEX128, value);
1922}
1923
1924TEST_F(CApiAttributesTest, TypeList) {
1925  const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
1926  const size_t list_size = TF_ARRAYSIZE(list);
1927
1928  auto desc = init("list(type)");
1929  TF_SetAttrTypeList(desc, "v", list, list_size);
1930
1931  auto oper = TF_FinishOperation(desc, s_);
1932  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1933
1934  TF_DataType values[list_size];
1935  EXPECT_TF_META("v", list_size, TF_ATTR_TYPE, -1);
1936  TF_OperationGetAttrTypeList(oper, "v", values, list_size, s_);
1937  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1938  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
1939}
1940
1941TEST_F(CApiAttributesTest, Shape) {
1942  // Unknown shape
1943  auto desc = init("shape");
1944  TF_SetAttrShape(desc, "v", nullptr, -1);
1945  auto oper = TF_FinishOperation(desc, s_);
1946  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1947  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, -1);
1948  TF_OperationGetAttrShape(oper, "v", nullptr, 10, s_);
1949  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1950
1951  // Partially specified shape
1952  const int64_t partial_shape[] = {17, -1};
1953  const size_t sz = TF_ARRAYSIZE(partial_shape);
1954  desc = init("shape");
1955  TF_SetAttrShape(desc, "v", partial_shape, sz);
1956  oper = TF_FinishOperation(desc, s_);
1957  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1958  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, sz);
1959  int64_t values[sz];
1960  TF_OperationGetAttrShape(oper, "v", values, sz, s_);
1961  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1962  EXPECT_TRUE(
1963      std::equal(std::begin(partial_shape), std::end(partial_shape), values));
1964}
1965
1966TEST_F(CApiAttributesTest, ShapeList) {
1967  const int64_t shape_1[] = {1, 3};
1968  const int64_t shape_2[] = {2, 4, 6};
1969  const int64_t* list[] = {&shape_1[0], &shape_2[0]};
1970  const size_t list_size = TF_ARRAYSIZE(list);
1971  const int ndims[] = {TF_ARRAYSIZE(shape_1), TF_ARRAYSIZE(shape_2)};
1972  const int total_ndims = 5;  // ndims[0] + ndims[1]
1973
1974  auto desc = init("list(shape)");
1975  TF_SetAttrShapeList(desc, "v", list, ndims, list_size);
1976  auto oper = TF_FinishOperation(desc, s_);
1977  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1978
1979  EXPECT_TF_META("v", list_size, TF_ATTR_SHAPE, total_ndims);
1980  int64_t* values[list_size];
1981  int values_ndims[list_size];
1982  int64_t storage[total_ndims];
1983  TF_OperationGetAttrShapeList(oper, "v", values, values_ndims, list_size,
1984                               storage, total_ndims, s_);
1985  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1986  for (size_t i = 0; i < list_size; ++i) {
1987    EXPECT_EQ(ndims[i], values_ndims[i]) << i;
1988    for (int j = 0; j < values_ndims[i]; ++j) {
1989      EXPECT_EQ(list[i][j], values[i][j]) << "(" << i << ", " << j << ")";
1990    }
1991  }
1992}
1993
1994TEST_F(CApiAttributesTest, TensorShapeProto) {
1995  const tensorflow::int64 pts[] = {2, 4, -1, 8};
1996  tensorflow::TensorShapeProto proto;
1997  tensorflow::PartialTensorShape(pts).AsProto(&proto);
1998  string bytes;
1999  proto.SerializeToString(&bytes);
2000
2001  auto desc = init("shape");
2002  TF_SetAttrTensorShapeProto(desc, "v", bytes.data(), bytes.length(), s_);
2003  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2004  auto oper = TF_FinishOperation(desc, s_);
2005  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2006
2007  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, 4);
2008  TF_Buffer* value = TF_NewBuffer();
2009  TF_OperationGetAttrTensorShapeProto(oper, "v", value, s_);
2010  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2011  EXPECT_EQ(bytes.length(), value->length);
2012  EXPECT_EQ(0, memcmp(bytes.data(), value->data, value->length));
2013  TF_DeleteBuffer(value);
2014}
2015
2016TEST_F(CApiAttributesTest, TensorShapeProtoList) {
2017  string bytes1, bytes2;
2018  tensorflow::TensorShapeProto proto;
2019
2020  const tensorflow::int64 pts1[] = {2, 4, -1, 8};
2021  tensorflow::PartialTensorShape(pts1).AsProto(&proto);
2022  proto.SerializeToString(&bytes1);
2023
2024  const tensorflow::int64 pts2[] = {1, 3, 5, 7};
2025  tensorflow::PartialTensorShape(pts2).AsProto(&proto);
2026  proto.SerializeToString(&bytes2);
2027
2028  std::unique_ptr<const void* []> list_ptrs;
2029  std::unique_ptr<size_t[]> list_lens;
2030  const std::vector<string> list = {bytes1, bytes2};
2031  StringVectorToArrays(list, &list_ptrs, &list_lens);
2032
2033  auto desc = init("list(shape)");
2034  TF_SetAttrTensorShapeProtoList(desc, "v", list_ptrs.get(), list_lens.get(),
2035                                 list.size(), s_);
2036  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2037  auto oper = TF_FinishOperation(desc, s_);
2038  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2039
2040  EXPECT_TF_META("v", 2, TF_ATTR_SHAPE, 8);
2041  TF_Buffer* values[2];
2042  TF_OperationGetAttrTensorShapeProtoList(oper, "v", values, 2, s_);
2043  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2044  for (int i = 0; i < 2; ++i) {
2045    int le = list_lens[i];
2046    int la = values[i]->length;
2047    const void* e = list_ptrs[i];
2048    const void* a = values[i]->data;
2049    EXPECT_EQ(le, la) << i;
2050    EXPECT_EQ(0, memcmp(e, a, std::min(le, la))) << i;
2051    TF_DeleteBuffer(values[i]);
2052  }
2053}
2054
2055TEST_F(CApiAttributesTest, Tensor) {
2056  const char tensor[] = {5, 7};
2057  const int64_t dims[] = {1, 2};
2058  const size_t ndims = TF_ARRAYSIZE(dims);
2059
2060  auto desc = init("tensor");
2061  unique_tensor_ptr v(Int8Tensor(dims, ndims, tensor), TF_DeleteTensor);
2062  TF_SetAttrTensor(desc, "v", v.get(), s_);
2063  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2064
2065  auto oper = TF_FinishOperation(desc, s_);
2066  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2067
2068  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
2069  TF_Tensor* value;
2070  TF_OperationGetAttrTensor(oper, "v", &value, s_);
2071  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2072  ASSERT_NE(nullptr, value);
2073  EXPECT_EQ(TF_INT8, TF_TensorType(value));
2074  EXPECT_EQ(ndims, TF_NumDims(value));
2075  for (int i = 0; i < TF_NumDims(value); ++i) {
2076    EXPECT_EQ(dims[i], TF_Dim(value, i)) << i;
2077  }
2078  EXPECT_EQ(sizeof(char) * TF_ARRAYSIZE(tensor), TF_TensorByteSize(value));
2079  EXPECT_EQ(0, memcmp(tensor, TF_TensorData(value), TF_TensorByteSize(value)));
2080  TF_DeleteTensor(value);
2081}
2082
2083TEST_F(CApiAttributesTest, StringTensor) {
2084  // Create the string-Tensor "attribute" value.
2085  char encoded[] = {
2086      0,   0, 0, 0, 0, 0, 0, 0,  // array[uint64] offsets
2087      1,                         // varint encoded string length
2088      'A',
2089  };
2090  auto deallocator = [](void* data, size_t len, void* arg) {};
2091  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
2092                                      sizeof(encoded), deallocator, nullptr),
2093                         TF_DeleteTensor);
2094
2095  // Create a TF_Operation with the attribute t_in
2096  auto desc = init("tensor");
2097  TF_SetAttrTensor(desc, "v", t_in.get(), s_);
2098  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2099
2100  auto oper = TF_FinishOperation(desc, s_);
2101  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2102
2103  // Fetch the attribute back.
2104  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
2105  TF_Tensor* t_out = nullptr;
2106  TF_OperationGetAttrTensor(oper, "v", &t_out, s_);
2107  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2108  EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
2109  EXPECT_EQ(0, TF_NumDims(t_out));
2110  ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
2111  EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
2112                      TF_TensorByteSize(t_out)));
2113  TF_DeleteTensor(t_out);
2114}
2115
2116TEST_F(CApiAttributesTest, TensorList) {
2117  const char tensor1[] = {5, 7};
2118  const int64_t dims1[] = {1, 2};
2119  const size_t ndims1 = TF_ARRAYSIZE(dims1);
2120
2121  const char tensor2[] = {2, 4, 6, 8};
2122  const int64_t dims2[] = {2, 2};
2123  const size_t ndims2 = TF_ARRAYSIZE(dims2);
2124
2125  auto desc = init("list(tensor)");
2126  TF_Tensor* tmp[] = {
2127      Int8Tensor(dims1, ndims1, tensor1),
2128      Int8Tensor(dims2, ndims2, tensor2),
2129  };
2130  TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_);
2131  for (int i = 0; i < TF_ARRAYSIZE(tmp); ++i) {
2132    TF_DeleteTensor(tmp[i]);
2133  }
2134  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2135  auto oper = TF_FinishOperation(desc, s_);
2136  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2137
2138  EXPECT_TF_META("v", 2, TF_ATTR_TENSOR, -1);
2139  TF_Tensor* values[2];
2140  TF_OperationGetAttrTensorList(oper, "v", &values[0], TF_ARRAYSIZE(values),
2141                                s_);
2142  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2143
2144  const char* tensor_data[] = {&tensor1[0], &tensor2[0]};
2145  const size_t tensor_size[] = {TF_ARRAYSIZE(tensor1), TF_ARRAYSIZE(tensor2)};
2146  const int64_t* tensor_dims[] = {&dims1[0], &dims2[0]};
2147  const size_t tensor_ndims[] = {ndims1, ndims2};
2148  for (int i = 0; i < 2; ++i) {
2149    TF_Tensor* v = values[i];
2150    ASSERT_NE(nullptr, v) << i;
2151    EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i;
2152    EXPECT_EQ(tensor_ndims[i], TF_NumDims(v)) << i;
2153    for (int j = 0; j < TF_NumDims(v); ++j) {
2154      EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j))
2155          << "Tensor #" << i << ", dimension #" << j;
2156    }
2157    EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i;
2158    EXPECT_EQ(0,
2159              memcmp(tensor_data[i], TF_TensorData(v), TF_TensorByteSize(v)));
2160    TF_DeleteTensor(v);
2161  }
2162}
2163
2164TEST_F(CApiAttributesTest, EmptyList) {
2165  auto desc = init("list(int)");
2166  TF_SetAttrIntList(desc, "v", nullptr, 0);
2167  auto oper = TF_FinishOperation(desc, s_);
2168  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2169  EXPECT_TF_META("v", 0, TF_ATTR_INT, -1);
2170}
2171
2172TEST_F(CApiAttributesTest, Errors) {
2173  auto desc = init("int");
2174  TF_SetAttrInt(desc, "v", 3);
2175  auto oper = TF_FinishOperation(desc, s_);
2176  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2177  TF_OperationGetAttrString(oper, "v", nullptr, 0, s_);
2178  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
2179}
2180
2181TEST(TestApiDef, TestCreateApiDef) {
2182  // TODO(b/73318067): Fix linking for the GPU test generated by the
2183  // tf_cuda_cc_test() bazel rule and remove the next line.
2184  if (!GPUDeviceName().empty()) return;
2185
2186  TF_Status* status = TF_NewStatus();
2187  TF_Library* lib =
2188      TF_LoadLibrary("tensorflow/c/test_op.so", status);
2189  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2190  TF_DeleteStatus(status);
2191
2192  TF_Buffer op_list_buf = TF_GetOpList(lib);
2193  status = TF_NewStatus();
2194  auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
2195  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2196  TF_DeleteStatus(status);
2197
2198  string op_name = "TestCApi";
2199  status = TF_NewStatus();
2200  auto* api_def_buf =
2201      TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
2202  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2203  TF_DeleteStatus(status);
2204
2205  tensorflow::ApiDef api_def;
2206  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
2207  EXPECT_EQ(op_name, api_def.graph_op_name());
2208  EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
2209
2210  TF_DeleteBuffer(api_def_buf);
2211  TF_DeleteApiDefMap(api_def_map);
2212  TF_DeleteLibraryHandle(lib);
2213}
2214
2215TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
2216  // TODO(b/73318067): Fix linking for the GPU test generated by the
2217  // tf_cuda_cc_test() bazel rule and remove the next line.
2218  if (!GPUDeviceName().empty()) return;
2219
2220  TF_Status* status = TF_NewStatus();
2221  TF_Library* lib =
2222      TF_LoadLibrary("tensorflow/c/test_op.so", status);
2223  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2224  TF_DeleteStatus(status);
2225
2226  TF_Buffer op_list_buf = TF_GetOpList(lib);
2227  status = TF_NewStatus();
2228  auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
2229  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2230  TF_DeleteStatus(status);
2231
2232  string api_def_overwrites = R"(op: <
2233  graph_op_name: "TestCApi"
2234  summary: "New summary"
2235>
2236)";
2237  status = TF_NewStatus();
2238  TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(),
2239                  api_def_overwrites.size(), status);
2240  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2241  TF_DeleteStatus(status);
2242
2243  string op_name = "TestCApi";
2244  status = TF_NewStatus();
2245  auto* api_def_buf =
2246      TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
2247  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
2248  TF_DeleteStatus(status);
2249
2250  tensorflow::ApiDef api_def;
2251  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
2252  EXPECT_EQ(op_name, api_def.graph_op_name());
2253  EXPECT_EQ("New summary", api_def.summary());
2254
2255  TF_DeleteBuffer(api_def_buf);
2256  TF_DeleteApiDefMap(api_def_map);
2257  TF_DeleteLibraryHandle(lib);
2258}
2259
// The EXPECT_TF_META helper is not used past this point; undefine it so the
// macro name does not leak further.
#undef EXPECT_TF_META
2261
2262}  // namespace
2263}  // namespace tensorflow
2264
2265// TODO(josh11b): Test:
2266// * TF_SetDevice(desc, "/job:worker");
2267// * control inputs / outputs
2268// * targets
2269// * TF_DeleteGraph() before TF_DeleteSession()
2270