1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/c_api.h"
17
18#include <algorithm>
19#include <limits>
20#include <memory>
21#include <vector>
22
23#ifndef __ANDROID__
24#include "tensorflow/cc/framework/gradients.h"
25#include "tensorflow/cc/framework/ops.h"
26#include "tensorflow/cc/framework/scope_internal.h"
27#include "tensorflow/cc/ops/while_loop.h"
28#include "tensorflow/cc/saved_model/loader.h"
29#include "tensorflow/core/framework/op_gen_lib.h"
30#endif
31#include "tensorflow/c/c_api_internal.h"
32#include "tensorflow/core/common_runtime/device_mgr.h"
33#include "tensorflow/core/common_runtime/shape_refiner.h"
34#include "tensorflow/core/framework/allocation_description.pb.h"
35#include "tensorflow/core/framework/log_memory.h"
36#include "tensorflow/core/framework/node_def_util.h"
37#include "tensorflow/core/framework/op_kernel.h"
38#include "tensorflow/core/framework/partial_tensor_shape.h"
39#include "tensorflow/core/framework/tensor.h"
40#include "tensorflow/core/framework/tensor_shape.h"
41#include "tensorflow/core/framework/tensor_shape.pb.h"
42#include "tensorflow/core/framework/types.h"
43#include "tensorflow/core/framework/versions.pb.h"
44#include "tensorflow/core/graph/graph.h"
45#include "tensorflow/core/graph/graph_constructor.h"
46#include "tensorflow/core/graph/node_builder.h"
47#include "tensorflow/core/lib/core/coding.h"
48#include "tensorflow/core/lib/core/errors.h"
49#include "tensorflow/core/lib/core/status.h"
50#include "tensorflow/core/lib/core/stringpiece.h"
51#include "tensorflow/core/lib/gtl/array_slice.h"
52#include "tensorflow/core/lib/strings/strcat.h"
53#include "tensorflow/core/platform/mem.h"
54#include "tensorflow/core/platform/mutex.h"
55#include "tensorflow/core/platform/protobuf.h"
56#include "tensorflow/core/platform/thread_annotations.h"
57#include "tensorflow/core/platform/types.h"
58#include "tensorflow/core/public/session.h"
59#include "tensorflow/core/public/version.h"
60
61// The implementation below is at the top level instead of the
62// brain namespace because we are defining 'extern "C"' functions.
63using tensorflow::AllocationDescription;
64using tensorflow::DataType;
65using tensorflow::Graph;
66using tensorflow::GraphDef;
67using tensorflow::mutex_lock;
68using tensorflow::NameRangeMap;
69using tensorflow::NameRangesForNode;
70using tensorflow::NewSession;
71using tensorflow::Node;
72using tensorflow::NodeBuilder;
73using tensorflow::NodeDef;
74using tensorflow::OpDef;
75using tensorflow::OpRegistry;
76using tensorflow::PartialTensorShape;
77using tensorflow::RunMetadata;
78using tensorflow::RunOptions;
79using tensorflow::Session;
80using tensorflow::Status;
81using tensorflow::string;
82using tensorflow::Tensor;
83using tensorflow::TensorBuffer;
84using tensorflow::TensorId;
85using tensorflow::TensorShape;
86using tensorflow::TensorShapeProto;
87using tensorflow::VersionDef;
88using tensorflow::error::Code;
89using tensorflow::errors::FailedPrecondition;
90using tensorflow::errors::InvalidArgument;
91using tensorflow::gtl::ArraySlice;
92using tensorflow::strings::StrCat;
93
94extern "C" {
95
96// --------------------------------------------------------------------------
97const char* TF_Version() { return TF_VERSION_STRING; }
98
99// --------------------------------------------------------------------------
100size_t TF_DataTypeSize(TF_DataType dt) {
101  return static_cast<size_t>(
102      tensorflow::DataTypeSize(static_cast<DataType>(dt)));
103}
104
105// --------------------------------------------------------------------------
106
107TF_Status* TF_NewStatus() { return new TF_Status; }
108
109void TF_DeleteStatus(TF_Status* s) { delete s; }
110
111void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
112  if (code == TF_OK) {
113    s->status = Status::OK();
114    return;
115  }
116  s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
117}
118
119TF_Code TF_GetCode(const TF_Status* s) {
120  return static_cast<TF_Code>(s->status.code());
121}
122
123const char* TF_Message(const TF_Status* s) {
124  return s->status.error_message().c_str();
125}
126
127// --------------------------------------------------------------------------
128
129namespace {
130class TF_ManagedBuffer : public TensorBuffer {
131 public:
132  void* data_;
133  size_t len_;
134  void (*deallocator_)(void* data, size_t len, void* arg);
135  void* deallocator_arg_;
136
137  ~TF_ManagedBuffer() override {
138    (*deallocator_)(data_, len_, deallocator_arg_);
139  }
140
141  void* data() const override { return data_; }
142  size_t size() const override { return len_; }
143  TensorBuffer* root_buffer() override { return this; }
144  void FillAllocationDescription(AllocationDescription* proto) const override {
145    tensorflow::int64 rb = size();
146    proto->set_requested_bytes(rb);
147    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
148  }
149
150  // Prevents input forwarding from mutating this buffer.
151  bool OwnsMemory() const override { return false; }
152};
153
154void* allocate_tensor(const char* operation, size_t len) {
155  void* data =
156      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
157  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
158    tensorflow::LogMemory::RecordRawAllocation(
159        operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
160        len, data, tensorflow::cpu_allocator());
161  }
162  return data;
163}
164
165void deallocate_buffer(void* data, size_t len, void* arg) {
166  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
167    tensorflow::LogMemory::RecordRawDeallocation(
168        "TensorFlow C Api",
169        tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
170        tensorflow::cpu_allocator(), false);
171  }
172  tensorflow::cpu_allocator()->DeallocateRaw(data);
173}
174
175}  // namespace
176
177TF_Tensor::~TF_Tensor() { buffer->Unref(); }
178
179TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
180                             int num_dims, size_t len) {
181  void* data = allocate_tensor("TF_AllocateTensor", len);
182  return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
183                      nullptr);
184}
185
186TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
187                        void* data, size_t len,
188                        void (*deallocator)(void* data, size_t len, void* arg),
189                        void* deallocator_arg) {
190  std::vector<tensorflow::int64> dimvec(num_dims);
191  for (int i = 0; i < num_dims; ++i) {
192    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
193  }
194
195  TF_ManagedBuffer* buf = new TF_ManagedBuffer;
196  buf->len_ = len;
197  if (dtype != TF_STRING && dtype != TF_RESOURCE &&
198      tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
199      reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
200    // TF_STRING and TF_RESOURCE tensors have a different representation in
201    // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
202    // (any alignment requirements will be taken care of by TF_TensorToTensor
203    // and TF_TensorFromTensor).
204    //
205    // Other types have the same representation, so copy only if it is safe to
206    // do so.
207    buf->data_ = allocate_tensor("TF_NewTensor", len);
208    std::memcpy(buf->data_, data, len);
209    buf->deallocator_ = deallocate_buffer;
210    buf->deallocator_arg_ = nullptr;
211    // Free the original buffer.
212    deallocator(data, len, deallocator_arg);
213  } else {
214    buf->data_ = data;
215    buf->deallocator_ = deallocator;
216    buf->deallocator_arg_ = deallocator_arg;
217  }
218  TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
219  size_t elem_size = TF_DataTypeSize(dtype);
220  if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
221    delete ret;
222    return nullptr;
223  }
224  return ret;
225}
226
227TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
228  // It is safe to move the Tensor if and only if we own the unique reference to
229  // it. In that case, we might as well not delete and reallocate, but a future
230  // implementation might need to do so.
231  TensorBuffer* buf = tensor->buffer;
232  if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
233      buf->OwnsMemory()) {
234    return tensor;
235  }
236  return nullptr;
237}
238
239void TF_DeleteTensor(TF_Tensor* t) { delete t; }
240
241TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
242int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
243int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
244  return static_cast<int64_t>(t->shape.dim_size(dim_index));
245}
246size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
247void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
248
249// --------------------------------------------------------------------------
250size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
251                       size_t dst_len, TF_Status* status) {
252  const size_t sz = TF_StringEncodedSize(src_len);
253  if (sz < src_len) {
254    status->status = InvalidArgument("src string is too large to encode");
255    return 0;
256  }
257  if (dst_len < sz) {
258    status->status =
259        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
260                        src_len, "-byte string");
261    return 0;
262  }
263  dst = tensorflow::core::EncodeVarint64(dst, src_len);
264  memcpy(dst, src, src_len);
265  return sz;
266}
267
268static Status TF_StringDecode_Impl(const char* src, size_t src_len,
269                                   const char** dst, size_t* dst_len) {
270  tensorflow::uint64 len64 = 0;
271  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
272  if (p == nullptr) {
273    return InvalidArgument("invalid string encoding or truncated src buffer");
274  }
275  if (len64 > std::numeric_limits<size_t>::max()) {
276    return InvalidArgument("encoded string is ", len64,
277                           "-bytes, which is too large for this architecture");
278  }
279  *dst = p;
280  *dst_len = static_cast<size_t>(len64);
281  return Status::OK();
282}
283
284size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
285                       size_t* dst_len, TF_Status* status) {
286  status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
287  if (!status->status.ok()) return 0;
288  return static_cast<size_t>(*dst - src) + *dst_len;
289}
290
291size_t TF_StringEncodedSize(size_t len) {
292  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
293}
294
295// --------------------------------------------------------------------------
296TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
297void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
298
299void TF_SetTarget(TF_SessionOptions* options, const char* target) {
300  options->options.target = target;
301}
302
303void TF_SetConfig(TF_SessionOptions* options, const void* proto,
304                  size_t proto_len, TF_Status* status) {
305  if (!options->options.config.ParseFromArray(proto, proto_len)) {
306    status->status = InvalidArgument("Unparseable ConfigProto");
307  }
308}
309// --------------------------------------------------------------------------
310TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
311
312TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
313  void* copy = tensorflow::port::Malloc(proto_len);
314  memcpy(copy, proto, proto_len);
315
316  TF_Buffer* buf = new TF_Buffer;
317  buf->data = copy;
318  buf->length = proto_len;
319  buf->data_deallocator = [](void* data, size_t length) {
320    tensorflow::port::Free(data);
321  };
322  return buf;
323}
324
325void TF_DeleteBuffer(TF_Buffer* buffer) {
326  if (buffer->data_deallocator != nullptr) {
327    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
328                                buffer->length);
329  }
330  delete buffer;
331}
332
333TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
334
335// --------------------------------------------------------------------------
336
337TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
338                                              TF_Status* status) {
339  Session* session;
340  status->status = NewSession(opt->options, &session);
341  if (status->status.ok()) {
342    return new TF_DeprecatedSession({session});
343  } else {
344    DCHECK_EQ(nullptr, session);
345    return nullptr;
346  }
347}
348
349void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
350  status->status = s->session->Close();
351}
352
353void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
354  status->status = Status::OK();
355  delete s->session;
356  delete s;
357}
358
359void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
360                    size_t proto_len, TF_Status* status) {
361  GraphDef g;
362  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
363    status->status = InvalidArgument("Invalid GraphDef");
364    return;
365  }
366  status->status = s->session->Extend(g);
367}
368
369static void DeleteArray(void* data, size_t size, void* arg) {
370  DCHECK_EQ(data, arg);
371  delete[] reinterpret_cast<char*>(arg);
372}
373
374}  // end extern "C"
375
376namespace tensorflow {
377namespace {
378
379// Reset helper for converting character arrays to string vectors.
380void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
381                     int ncontainers, TF_Status* status) {
382  std::vector<string> container_names(ncontainers);
383  for (int i = 0; i < ncontainers; ++i) {
384    container_names[i] = containers[i];
385  }
386
387  status->status = Reset(opt->options, container_names);
388}
389
390// This traverses the specified nodes in topological order to verify there are
391// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
392// all been visited, and counts the total number of visited nodes. If there is a
393// cycle, nodes in the cycle will never be visited, and the visited count will
394// be less than the total node count.
395Status ValidateNoCycles(const Graph& g) {
396  // TODO(nolivia): check this on a subset of the graph instead of all of it.
397  // A node is ready when all of its inputs have been visited.
398  std::vector<const Node*> ready;
399  std::vector<int> pending_count(g.num_node_ids(), 0);
400
401  for (int i = 0; i < g.num_node_ids(); ++i) {
402    const Node* n = g.FindNodeId(i);
403    if (n == nullptr) continue;
404    pending_count[i] = n->in_edges().size();
405    if (n->IsMerge()) {
406      // While-loop cycles are legal cycles so we manually adjust the
407      // pending_count to make sure that the loop is visited.
408      for (const Edge* e : n->in_edges()) {
409        if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
410          pending_count[i]--;
411        }
412      }
413    }
414    if (pending_count[i] == 0) {
415      ready.push_back(n);
416    }
417  }
418
419  int processed = 0;
420  while (!ready.empty()) {
421    const Node* node = ready.back();
422    ready.pop_back();
423    ++processed;
424
425    for (const Edge* out : node->out_edges()) {
426      const int output_id = out->dst()->id();
427      pending_count[output_id]--;
428      if (pending_count[output_id] == 0) {
429        ready.push_back(out->dst());
430      }
431    }
432  }
433
434  if (processed < g.num_nodes()) {
435    std::vector<string> nodes_in_cycle;
436    for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
437         ++i) {
438      if (pending_count[i] != 0) {
439        nodes_in_cycle.push_back(g.FindNodeId(i)->name());
440      }
441    }
442    return errors::InvalidArgument(
443        "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
444        " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
445  }
446  return Status::OK();
447}
448}  // namespace
449}  // namespace tensorflow
450
451extern "C" {
452
453void TF_Reset(const TF_SessionOptions* opt, const char** containers,
454              int ncontainers, TF_Status* status) {
455  tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
456}
457
458}  // end extern "C"
459
460namespace tensorflow {
461
462Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
463  if (src->dtype == TF_RESOURCE) {
464    if (src->shape.dims() != 0) {
465      return InvalidArgument(
466          "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
467          "shape ",
468          src->shape.DebugString());
469    }
470    *dst = Tensor(DT_RESOURCE, src->shape);
471    if (!dst->scalar<ResourceHandle>()().ParseFromString(
472            string(static_cast<const char*>(TF_TensorData(src)),
473                   TF_TensorByteSize(src)))) {
474      return InvalidArgument(
475          "Malformed TF_RESOUCE tensor: unable to parse resource handle");
476    }
477    return Status::OK();
478  }
479  if (src->dtype != TF_STRING) {
480    *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
481    return Status::OK();
482  }
483  // TF_STRING tensors require copying since Tensor class expects a sequence of
484  // string objects.
485  const tensorflow::int64 num_elements = src->shape.num_elements();
486  const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
487  const size_t src_size = TF_TensorByteSize(src);
488  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
489      num_elements) {
490    return InvalidArgument(
491        "Malformed TF_STRING tensor; too short to hold number of elements");
492  }
493  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
494  const char* limit = input + src_size;
495
496  *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
497  auto dstarray = dst->flat<string>();
498  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
499    tensorflow::uint64 offset =
500        reinterpret_cast<const tensorflow::uint64*>(input)[i];
501    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
502      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
503                             " out of range");
504    }
505    size_t len;
506    const char* p;
507    const char* srcp = data_start + offset;
508    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
509    if (!status.ok()) return status;
510    dstarray(i).assign(p, len);
511  }
512  return Status::OK();
513}
514
515// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
516// result in a zero-sized tensor.
517static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
518  static char empty;
519  tensorflow::int64 nelems = 1;
520  std::vector<tensorflow::int64> dims;
521  for (int i = 0; i < shape.dims(); ++i) {
522    dims.push_back(shape.dim_size(i));
523    nelems *= shape.dim_size(i);
524  }
525  CHECK_EQ(nelems, 0);
526  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
527                "64-bit int types should match in size");
528  return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
529                      shape.dims(), reinterpret_cast<void*>(&empty), 0,
530                      [](void*, size_t, void*) {}, nullptr);
531}
532
533// Non-static for testing.
534TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
535                               TF_Status* status) {
536  if (!src.IsInitialized()) {
537    status->status = FailedPrecondition(
538        "attempt to use a tensor with an uninitialized value");
539    return nullptr;
540  }
541  if (src.NumElements() == 0) {
542    return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
543  }
544  if (src.dtype() == DT_RESOURCE) {
545    if (src.shape().dims() != 0) {
546      status->status = InvalidArgument(
547          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
548          src.shape().DebugString(),
549          "). Please file a bug at "
550          "https://github.com/tensorflow/tensorflow/issues/new, "
551          "ideally with a "
552          "short code snippet that reproduces this error.");
553      return nullptr;
554    }
555    const string str = src.scalar<ResourceHandle>()().SerializeAsString();
556    TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
557    std::memcpy(TF_TensorData(t), str.c_str(), str.size());
558    return t;
559  }
560  if (src.dtype() != DT_STRING) {
561    TensorBuffer* buf = TensorCApi::Buffer(src);
562    buf->Ref();
563    return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
564                         buf};
565  }
566  // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
567  // encoded sequence of strings.
568
569  // Compute bytes needed for encoding.
570  size_t size = 0;
571  const auto& srcarray = src.flat<string>();
572  for (int i = 0; i < srcarray.size(); ++i) {
573    const string& s = srcarray(i);
574    // uint64 starting_offset, TF_StringEncode-d string.
575    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
576  }
577
578  // Encode all strings.
579  char* base = new char[size];
580  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
581  char* dst = data_start;  // Where next string is encoded.
582  size_t dst_len = size - static_cast<size_t>(data_start - base);
583  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
584  for (int i = 0; i < srcarray.size(); ++i) {
585    *offsets = (dst - data_start);
586    offsets++;
587    const string& s = srcarray(i);
588    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
589    if (!status->status.ok()) {
590      status->status = InvalidArgument(
591          "invalid string tensor encoding (string #", i, " of ",
592          srcarray.size(), "): ", status->status.error_message());
593      delete[] base;
594      return nullptr;
595    }
596    dst += consumed;
597    dst_len -= consumed;
598  }
599  if (dst != base + size) {
600    status->status = InvalidArgument(
601        "invalid string tensor encoding (decoded ", (dst - base),
602        " bytes, but the tensor is encoded in ", size, " bytes");
603    delete[] base;
604    return nullptr;
605  }
606
607  auto dims = src.shape().dim_sizes();
608  std::vector<tensorflow::int64> dimvec(dims.size());
609  for (size_t i = 0; i < dims.size(); ++i) {
610    dimvec[i] = dims[i];
611  }
612  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
613                "64-bit int types should match in size");
614  return TF_NewTensor(TF_STRING,
615                      reinterpret_cast<const int64_t*>(dimvec.data()),
616                      dimvec.size(), base, size, DeleteArray, base);
617}
618
619Status MessageToBuffer(const tensorflow::protobuf::Message& in,
620                       TF_Buffer* out) {
621  if (out->data != nullptr) {
622    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
623  }
624  const size_t proto_size = in.ByteSizeLong();
625  void* buf = tensorflow::port::Malloc(proto_size);
626  if (buf == nullptr) {
627    return tensorflow::errors::ResourceExhausted(
628        "Failed to allocate memory to serialize message of type '",
629        in.GetTypeName(), "' and size ", proto_size);
630  }
631  in.SerializeToArray(buf, proto_size);
632  out->data = buf;
633  out->length = proto_size;
634  out->data_deallocator = [](void* data, size_t length) {
635    tensorflow::port::Free(data);
636  };
637  return Status::OK();
638}
639
640void RecordMutation(TF_Graph* graph, const TF_Operation& op,
641                    const char* mutation_type)
642    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
643  // If any session has already run this node_id, mark this session as
644  // unrunnable.
645  for (auto it : graph->sessions) {
646    if (it.first->last_num_graph_nodes > op.node.id()) {
647      it.second = FailedPrecondition(
648          "Operation '", op.node.DebugString(), "' was changed by ",
649          mutation_type,
650          " after it was run by a session. Nodes can be mutated "
651          "only before they are executed by a session. Either don't modify "
652          "nodes after running them or create a new session.");
653    }
654  }
655}
656
657namespace {
658
659// Helper method that creates a shape handle for a shape described by dims.
660tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
661    tensorflow::shape_inference::InferenceContext* ic, int num_dims,
662    const int64_t* dims) {
663  if (num_dims != -1) {
664    std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
665    dim_vec.reserve(num_dims);
666    for (int i = 0; i < num_dims; ++i) {
667      dim_vec.push_back(ic->MakeDim(dims[i]));
668    }
669    return ic->MakeShape(dim_vec);
670  } else {
671    return ic->UnknownShape();
672  }
673}
674
675}  // namespace
676
677void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
678                                           int num_shapes_and_types,
679                                           const int64_t** shapes,
680                                           const int* ranks,
681                                           const TF_DataType* types,
682                                           TF_Status* status) {
683  Node* node = &output.oper->node;
684
685  mutex_lock l(graph->mu);
686  tensorflow::shape_inference::InferenceContext* ic =
687      graph->refiner.GetContext(node);
688  if (ic == nullptr) {
689    status->status =
690        InvalidArgument("Node ", node->name(), " was not found in the graph");
691    return;
692  }
693
694  auto shape_and_type_vec =
695      std::vector<tensorflow::shape_inference::ShapeAndType>(
696          num_shapes_and_types);
697  for (int i = 0; i < num_shapes_and_types; ++i) {
698    tensorflow::shape_inference::ShapeHandle shape_handle =
699        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
700    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
701        shape_handle, static_cast<DataType>(types[i]));
702  }
703
704  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
705}
706
707// Helpers for loading a TensorFlow plugin (a .so file).
708Status LoadLibrary(const char* library_filename, void** result,
709                   const void** buf, size_t* len);
710
711}  // namespace tensorflow
712
713static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
714                         TF_Status* status) {
715  status->status = Status::OK();
716  for (int i = 0; i < noutputs; ++i) {
717    c_outputs[i] = nullptr;
718  }
719}
720
721static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
722                          std::vector<std::pair<string, Tensor>>* input_pairs,
723                          TF_Status* status) {
724  const int ninputs = input_pairs->size();
725  for (int i = 0; i < ninputs; ++i) {
726    status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
727    if (!status->status.ok()) return false;
728  }
729  return true;
730}
731
732static void TF_Run_Helper(
733    Session* session, const char* handle, const TF_Buffer* run_options,
734    // Input tensors
735    const std::vector<std::pair<string, Tensor>>& input_pairs,
736    // Output tensors
737    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
738    // Target nodes
739    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
740    TF_Status* status) {
741  const int noutputs = output_tensor_names.size();
742  std::vector<Tensor> outputs(noutputs);
743  Status result;
744
745  if (handle == nullptr) {
746    RunOptions run_options_proto;
747    if (run_options != nullptr && !run_options_proto.ParseFromArray(
748                                      run_options->data, run_options->length)) {
749      status->status = InvalidArgument("Unparseable RunOptions proto");
750      return;
751    }
752    if (run_metadata != nullptr && run_metadata->data != nullptr) {
753      status->status =
754          InvalidArgument("Passing non-empty run_metadata is invalid.");
755      return;
756    }
757
758    RunMetadata run_metadata_proto;
759    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
760                          target_oper_names, &outputs, &run_metadata_proto);
761
762    // Serialize back to upstream client, who now owns the new buffer
763    if (run_metadata != nullptr) {
764      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
765      if (!status->status.ok()) return;
766    }
767  } else {
768    // NOTE(zongheng): PRun does not support RunOptions yet.
769    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
770  }
771  if (!result.ok()) {
772    status->status = result;
773    return;
774  }
775
776  // Store results in c_outputs[]
777  for (int i = 0; i < noutputs; ++i) {
778    const Tensor& src = outputs[i];
779    if (!src.IsInitialized() || src.NumElements() == 0) {
780      c_outputs[i] =
781          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
782      continue;
783    }
784    c_outputs[i] = TF_TensorFromTensor(src, status);
785    if (!status->status.ok()) return;
786  }
787}
788
789extern "C" {
790
791void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
792            // Input tensors
793            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
794            // Output tensors
795            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
796            // Target nodes
797            const char** c_target_oper_names, int ntargets,
798            TF_Buffer* run_metadata, TF_Status* status) {
799  TF_Run_Setup(noutputs, c_outputs, status);
800  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
801  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
802  for (int i = 0; i < ninputs; ++i) {
803    input_pairs[i].first = c_input_names[i];
804  }
805  std::vector<string> output_names(noutputs);
806  for (int i = 0; i < noutputs; ++i) {
807    output_names[i] = c_output_names[i];
808  }
809  std::vector<string> target_oper_names(ntargets);
810  for (int i = 0; i < ntargets; ++i) {
811    target_oper_names[i] = c_target_oper_names[i];
812  }
813  TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
814                c_outputs, target_oper_names, run_metadata, status);
815}
816
817void TF_PRunSetup(TF_DeprecatedSession* s,
818                  // Input names
819                  const char** c_input_names, int ninputs,
820                  // Output names
821                  const char** c_output_names, int noutputs,
822                  // Target nodes
823                  const char** c_target_oper_names, int ntargets,
824                  const char** handle, TF_Status* status) {
825  *handle = nullptr;
826
827  std::vector<string> input_names(ninputs);
828  std::vector<string> output_names(noutputs);
829  std::vector<string> target_oper_names(ntargets);
830  for (int i = 0; i < ninputs; ++i) {
831    input_names[i] = c_input_names[i];
832  }
833  for (int i = 0; i < noutputs; ++i) {
834    output_names[i] = c_output_names[i];
835  }
836  for (int i = 0; i < ntargets; ++i) {
837    target_oper_names[i] = c_target_oper_names[i];
838  }
839  string new_handle;
840  status->status = s->session->PRunSetup(input_names, output_names,
841                                         target_oper_names, &new_handle);
842  if (status->status.ok()) {
843    char* buf = new char[new_handle.size() + 1];
844    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
845    *handle = buf;
846  }
847}
848
849void TF_PRun(TF_DeprecatedSession* s, const char* handle,
850             // Input tensors
851             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
852             // Output tensors
853             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
854             // Target nodes
855             const char** c_target_oper_names, int ntargets,
856             TF_Status* status) {
857  TF_Run_Setup(noutputs, c_outputs, status);
858  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
859  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
860  for (int i = 0; i < ninputs; ++i) {
861    input_pairs[i].first = c_input_names[i];
862  }
863
864  std::vector<string> output_names(noutputs);
865  for (int i = 0; i < noutputs; ++i) {
866    output_names[i] = c_output_names[i];
867  }
868  std::vector<string> target_oper_names(ntargets);
869  for (int i = 0; i < ntargets; ++i) {
870    target_oper_names[i] = c_target_oper_names[i];
871  }
872  TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
873                c_outputs, target_oper_names, nullptr, status);
874}
875
876TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
877  TF_Library* lib_handle = new TF_Library;
878  status->status = tensorflow::LoadLibrary(
879      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
880      &lib_handle->op_list.length);
881  if (!status->status.ok()) {
882    delete lib_handle;
883    return nullptr;
884  }
885  return lib_handle;
886}
887
888TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
889
890void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
891  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
892  delete lib_handle;
893}
894
895TF_Buffer* TF_GetAllOpList() {
896  std::vector<tensorflow::OpDef> op_defs;
897  tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
898  tensorflow::OpList op_list;
899  for (const auto& op : op_defs) {
900    *(op_list.add_op()) = op;
901  }
902  TF_Buffer* ret = TF_NewBuffer();
903  TF_CHECK_OK(MessageToBuffer(op_list, ret));
904  return ret;
905}
906
907// --------------------------------------------------------------------------
908// ListDevices & SessionListDevices API
909
910void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
911
912TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
913  TF_DeviceList* response = new TF_DeviceList;
914  status->status = session->session->ListDevices(&response->response);
915  return response;
916}
917
918TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
919                                               TF_Status* status) {
920  TF_DeviceList* response = new TF_DeviceList;
921  status->status = session->session->ListDevices(&response->response);
922  return response;
923}
924
925int TF_DeviceListCount(const TF_DeviceList* list) {
926  return list->response.size();
927}
928
// Generates a checked accessor
//   return_type method_name(const TF_DeviceList* list, int index,
//                           TF_Status* status)
// that returns `list->response[index].accessor`.
// - A null `list` sets InvalidArgument("list is null!") and returns err_val.
// - An index outside [0, list->response.size()) sets
//   InvalidArgument("index out of bounds") and returns err_val.
// - Otherwise it sets `status` to OK.
// NOTE: comments cannot go inside the macro body. A `//` comment would swallow
// the trailing line-continuation backslash.
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

// The const char* accessors return pointers into the strings stored in `list`.
// They presumably stay valid only while the list is alive, i.e. until
// TF_DeleteDeviceList.
TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
// Returns -1 on error.
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);

#undef TF_DEVICELIST_METHOD
950
951}  // end extern "C"
952
953// --------------------------------------------------------------------------
954// New Graph and Session API
955
956// Helper functions -----------------------------------------------------------
957
958namespace {
959
960TF_Operation* ToOperation(Node* node) {
961  return static_cast<TF_Operation*>(static_cast<void*>(node));
962}
963
964string OutputName(const TF_Output& output) {
965  return StrCat(output.oper->node.name(), ":", output.index);
966}
967
968const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
969                                          const char* attr_name,
970                                          TF_Status* status) {
971  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
972  if (attr == nullptr) {
973    status->status = InvalidArgument("Operation '", oper->node.name(),
974                                     "' has no attr named '", attr_name, "'.");
975  }
976  return attr;
977}
978
979TensorId ToTensorId(const TF_Output& output) {
980  return TensorId(output.oper->node.name(), output.index);
981}
982
983#ifndef __ANDROID__
984std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
985                                                     int n) {
986  std::vector<tensorflow::Output> outputs(n);
987  for (int i = 0; i < n; ++i) {
988    outputs[i] =
989        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
990  }
991  return outputs;
992}
993
994void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
995                          TF_Output* tf_outputs) {
996  for (int i = 0; i < outputs.size(); i++) {
997    tf_outputs[i].oper = ToOperation(outputs[i].node());
998    tf_outputs[i].index = outputs[i].index();
999  }
1000}
1001#endif  // __ANDROID__
1002
1003}  // namespace
1004
1005// Shape functions -----------------------------------------------------------
1006
1007void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
1008                            const int64_t* dims, const int num_dims,
1009                            TF_Status* status) {
1010  Node* node = &output.oper->node;
1011
1012  mutex_lock l(graph->mu);
1013  tensorflow::shape_inference::InferenceContext* ic =
1014      graph->refiner.GetContext(node);
1015  if (ic == nullptr) {
1016    status->status =
1017        InvalidArgument("Node ", node->name(), " was not found in the graph");
1018    return;
1019  }
1020  tensorflow::shape_inference::ShapeHandle new_shape =
1021      tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
1022  status->status = graph->refiner.SetShape(node, output.index, new_shape);
1023}
1024
1025int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
1026                             TF_Status* status) {
1027  Node* node = &output.oper->node;
1028
1029  mutex_lock l(graph->mu);
1030  tensorflow::shape_inference::InferenceContext* ic =
1031      graph->refiner.GetContext(node);
1032  if (ic == nullptr) {
1033    status->status =
1034        InvalidArgument("Node ", node->name(), " was not found in the graph");
1035    return -1;
1036  }
1037
1038  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1039
1040  // Unknown rank means the number of dimensions is -1.
1041  if (!ic->RankKnown(shape)) {
1042    return -1;
1043  }
1044
1045  return ic->Rank(shape);
1046}
1047
1048void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
1049                            int num_dims, TF_Status* status) {
1050  Node* node = &output.oper->node;
1051
1052  mutex_lock l(graph->mu);
1053  tensorflow::shape_inference::InferenceContext* ic =
1054      graph->refiner.GetContext(node);
1055  if (ic == nullptr) {
1056    status->status =
1057        InvalidArgument("Node ", node->name(), " was not found in the graph");
1058    return;
1059  }
1060
1061  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1062
1063  int rank = -1;
1064  if (ic->RankKnown(shape)) {
1065    rank = ic->Rank(shape);
1066  }
1067
1068  if (num_dims != rank) {
1069    status->status = InvalidArgument("Expected rank is ", num_dims,
1070                                     " but actual rank is ", rank);
1071    return;
1072  }
1073
1074  if (num_dims == 0) {
1075    // Output shape is a scalar.
1076    return;
1077  }
1078
1079  // Rank is greater than 0, so fill in the values, if known, and
1080  // -1 for unknown values.
1081  for (int i = 0; i < num_dims; ++i) {
1082    auto dim = ic->Dim(shape, i);
1083    tensorflow::int64 value = -1;
1084    if (ic->ValueKnown(dim)) {
1085      value = ic->Value(dim);
1086    }
1087    dims[i] = value;
1088  }
1089}
1090
1091// TF_OperationDescription functions ------------------------------------------
1092
1093extern "C" {
1094
1095static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1096                                                      const char* op_type,
1097                                                      const char* oper_name)
1098    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1099  return new TF_OperationDescription(graph, op_type, oper_name);
1100}
1101
1102TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1103                                         const char* oper_name) {
1104  mutex_lock l(graph->mu);
1105  return TF_NewOperationLocked(graph, op_type, oper_name);
1106}
1107
1108void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1109  desc->node_builder.Device(device);
1110}
1111
1112void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1113  desc->node_builder.Input(&input.oper->node, input.index);
1114}
1115
1116void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1117                     int num_inputs) {
1118  std::vector<NodeBuilder::NodeOut> input_list;
1119  input_list.reserve(num_inputs);
1120  for (int i = 0; i < num_inputs; ++i) {
1121    input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1122  }
1123  desc->node_builder.Input(input_list);
1124}
1125
1126void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1127  desc->node_builder.ControlInput(&input->node);
1128}
1129
1130void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1131  desc->colocation_constraints.emplace(
1132      StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1133}
1134
1135void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1136                      const void* value, size_t length) {
1137  tensorflow::StringPiece s(static_cast<const char*>(value), length);
1138  desc->node_builder.Attr(attr_name, s);
1139}
1140
1141void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1142                          const void* const* values, const size_t* lengths,
1143                          int num_values) {
1144  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1145    desc->colocation_constraints.clear();
1146    for (int i = 0; i < num_values; ++i) {
1147      desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1148                                           lengths[i]);
1149    }
1150  } else {
1151    std::vector<tensorflow::StringPiece> v;
1152    v.reserve(num_values);
1153    for (int i = 0; i < num_values; ++i) {
1154      v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1155    }
1156    desc->node_builder.Attr(attr_name, v);
1157  }
1158}
1159
1160void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1161                   int64_t value) {
1162  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1163                "64-bit int types should match in size");
1164  desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1165}
1166
1167void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1168                       const int64_t* values, int num_values) {
1169  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1170                "64-bit int types should match in size");
1171  desc->node_builder.Attr(
1172      attr_name,
1173      ArraySlice<const tensorflow::int64>(
1174          reinterpret_cast<const tensorflow::int64*>(values), num_values));
1175}
1176
1177void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1178                     float value) {
1179  desc->node_builder.Attr(attr_name, value);
1180}
1181
1182void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1183                         const float* values, int num_values) {
1184  desc->node_builder.Attr(attr_name,
1185                          ArraySlice<const float>(values, num_values));
1186}
1187
1188void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1189                    unsigned char value) {
1190  desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1191}
1192
1193void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1194                        const unsigned char* values, int num_values) {
1195  std::unique_ptr<bool[]> b(new bool[num_values]);
1196  for (int i = 0; i < num_values; ++i) {
1197    b[i] = values[i];
1198  }
1199  desc->node_builder.Attr(attr_name,
1200                          ArraySlice<const bool>(b.get(), num_values));
1201}
1202
1203void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1204                    TF_DataType value) {
1205  desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1206}
1207
1208void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1209                        const TF_DataType* values, int num_values) {
1210  desc->node_builder.Attr(
1211      attr_name, ArraySlice<const DataType>(
1212                     reinterpret_cast<const DataType*>(values), num_values));
1213}
1214
1215void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1216                        const char* value, size_t length) {
1217  tensorflow::NameAttrList func_name;
1218  func_name.set_name(std::string(value, value + length));
1219  desc->node_builder.Attr(attr_name, func_name);
1220}
1221
1222void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1223                     const int64_t* dims, int num_dims) {
1224  PartialTensorShape shape;
1225  if (num_dims >= 0) {
1226    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1227                  "64-bit int types should match in size");
1228    shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1229        reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1230  }
1231  desc->node_builder.Attr(attr_name, shape);
1232}
1233
1234void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1235                         const int64_t* const* dims, const int* num_dims,
1236                         int num_shapes) {
1237  std::vector<PartialTensorShape> shapes;
1238  shapes.reserve(num_shapes);
1239  for (int i = 0; i < num_shapes; ++i) {
1240    if (num_dims[i] < 0) {
1241      shapes.emplace_back();
1242    } else {
1243      static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1244                    "64-bit int types should match in size");
1245      shapes.emplace_back(ArraySlice<tensorflow::int64>(
1246          reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1247    }
1248  }
1249  desc->node_builder.Attr(attr_name, shapes);
1250}
1251
1252void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1253                                const char* attr_name, const void* proto,
1254                                size_t proto_len, TF_Status* status) {
1255  // shape.ParseFromArray takes an int as length, this function takes size_t,
1256  // make sure there is no information loss.
1257  if (proto_len > std::numeric_limits<int>::max()) {
1258    status->status = InvalidArgument(
1259        "proto_len (", proto_len,
1260        " bytes) is too large to be parsed by the protocol buffer library");
1261    return;
1262  }
1263  TensorShapeProto shape;
1264  if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1265    desc->node_builder.Attr(attr_name, shape);
1266    status->status = Status::OK();
1267  } else {
1268    status->status = InvalidArgument("Unparseable TensorShapeProto");
1269  }
1270}
1271
1272void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1273                                    const char* attr_name,
1274                                    const void* const* protos,
1275                                    const size_t* proto_lens, int num_shapes,
1276                                    TF_Status* status) {
1277  std::vector<TensorShapeProto> shapes;
1278  shapes.resize(num_shapes);
1279  for (int i = 0; i < num_shapes; ++i) {
1280    if (proto_lens[i] > std::numeric_limits<int>::max()) {
1281      status->status = InvalidArgument(
1282          "length of element ", i, " in the list (", proto_lens[i],
1283          " bytes) is too large to be parsed by the protocol buffer library");
1284      return;
1285    }
1286    if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1287      status->status =
1288          InvalidArgument("Unparseable TensorShapeProto at index ", i);
1289      return;
1290    }
1291  }
1292  desc->node_builder.Attr(attr_name, shapes);
1293  status->status = Status::OK();
1294}
1295
1296void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1297                      TF_Tensor* value, TF_Status* status) {
1298  Tensor t;
1299  status->status = TF_TensorToTensor(value, &t);
1300  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1301}
1302
1303void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1304                          TF_Tensor* const* values, int num_values,
1305                          TF_Status* status) {
1306  status->status = Status::OK();
1307  std::vector<Tensor> t;
1308  t.reserve(num_values);
1309
1310  for (int i = 0; i < num_values && status->status.ok(); ++i) {
1311    Tensor v;
1312    status->status = TF_TensorToTensor(values[i], &v);
1313    t.emplace_back(v);
1314  }
1315
1316  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1317}
1318
1319void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1320                          const void* proto, size_t proto_len,
1321                          TF_Status* status) {
1322  tensorflow::AttrValue attr_value;
1323  if (!attr_value.ParseFromArray(proto, proto_len)) {
1324    status->status = InvalidArgument("Unparseable AttrValue proto");
1325    return;
1326  }
1327
1328  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1329    if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1330        attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1331      status->status =
1332          InvalidArgument("Expected \"list\" field for \"",
1333                          tensorflow::kColocationAttrName, "\" attribute");
1334      return;
1335    }
1336    desc->colocation_constraints.clear();
1337    for (const string& location : attr_value.list().s()) {
1338      desc->colocation_constraints.insert(location);
1339    }
1340  } else {
1341    desc->node_builder.Attr(attr_name, attr_value);
1342  }
1343
1344  status->status = Status::OK();
1345}
1346
1347static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1348                                              TF_Status* status)
1349    EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1350  Node* ret = nullptr;
1351
1352  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1353    status->status = InvalidArgument("Duplicate node name in graph: '",
1354                                     desc->node_builder.node_name(), "'");
1355  } else {
1356    if (!desc->colocation_constraints.empty()) {
1357      desc->node_builder.Attr(
1358          tensorflow::kColocationAttrName,
1359          std::vector<string>(desc->colocation_constraints.begin(),
1360                              desc->colocation_constraints.end()));
1361    }
1362    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1363
1364    if (status->status.ok()) {
1365      // Run shape inference function for newly added node.
1366      status->status = desc->graph->refiner.AddNode(ret);
1367    }
1368    if (status->status.ok()) {
1369      // Add the node to the name-to-node mapping.
1370      desc->graph->name_map[ret->name()] = ret;
1371    } else if (ret != nullptr) {
1372      desc->graph->graph.RemoveNode(ret);
1373      ret = nullptr;
1374    }
1375  }
1376
1377  delete desc;
1378
1379  return ToOperation(ret);
1380}
1381
1382TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1383                                 TF_Status* status) {
1384  mutex_lock l(desc->graph->mu);
1385  return TF_FinishOperationLocked(desc, status);
1386}
1387
1388// TF_Operation functions
1389// ----------------------------------------------------------
1390
1391const char* TF_OperationName(TF_Operation* oper) {
1392  return oper->node.name().c_str();
1393}
1394
1395const char* TF_OperationOpType(TF_Operation* oper) {
1396  return oper->node.type_string().c_str();
1397}
1398
1399const char* TF_OperationDevice(TF_Operation* oper) {
1400  return oper->node.requested_device().c_str();
1401}
1402
1403int TF_OperationNumOutputs(TF_Operation* oper) {
1404  return oper->node.num_outputs();
1405}
1406
1407TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1408  return static_cast<TF_DataType>(
1409      oper_out.oper->node.output_type(oper_out.index));
1410}
1411
1412int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1413                                 TF_Status* status) {
1414  NameRangeMap name_ranges;
1415  status->status =
1416      NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1417  if (!status->status.ok()) return -1;
1418  auto iter = name_ranges.find(arg_name);
1419  if (iter == name_ranges.end()) {
1420    status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1421    return -1;
1422  }
1423  return iter->second.second - iter->second.first;
1424}
1425
1426int TF_OperationNumInputs(TF_Operation* oper) {
1427  return oper->node.num_inputs();
1428}
1429
1430TF_DataType TF_OperationInputType(TF_Input oper_in) {
1431  return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1432}
1433
1434int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1435                                TF_Status* status) {
1436  NameRangeMap name_ranges;
1437  status->status =
1438      NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1439  if (!status->status.ok()) return -1;
1440  auto iter = name_ranges.find(arg_name);
1441  if (iter == name_ranges.end()) {
1442    status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1443    return -1;
1444  }
1445  return iter->second.second - iter->second.first;
1446}
1447
1448TF_Output TF_OperationInput(TF_Input oper_in) {
1449  const tensorflow::Edge* edge;
1450  Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1451  if (!s.ok()) {
1452    return {nullptr, -1};
1453  }
1454
1455  return {ToOperation(edge->src()), edge->src_output()};
1456}
1457
1458int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1459  int count = 0;
1460  for (const auto* edge : oper_out.oper->node.out_edges()) {
1461    if (edge->src_output() == oper_out.index) {
1462      ++count;
1463    }
1464  }
1465  return count;
1466}
1467
1468int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1469                                int max_consumers) {
1470  int count = 0;
1471  for (const auto* edge : oper_out.oper->node.out_edges()) {
1472    if (edge->src_output() == oper_out.index) {
1473      if (count < max_consumers) {
1474        consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1475      }
1476      ++count;
1477    }
1478  }
1479  return count;
1480}
1481
1482int TF_OperationNumControlInputs(TF_Operation* oper) {
1483  int count = 0;
1484  for (const auto* edge : oper->node.in_edges()) {
1485    if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1486      ++count;
1487    }
1488  }
1489  return count;
1490}
1491
1492int TF_OperationGetControlInputs(TF_Operation* oper,
1493                                 TF_Operation** control_inputs,
1494                                 int max_control_inputs) {
1495  int count = 0;
1496  for (const auto* edge : oper->node.in_edges()) {
1497    if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1498      if (count < max_control_inputs) {
1499        control_inputs[count] = ToOperation(edge->src());
1500      }
1501      ++count;
1502    }
1503  }
1504  return count;
1505}
1506
1507int TF_OperationNumControlOutputs(TF_Operation* oper) {
1508  int count = 0;
1509  for (const auto* edge : oper->node.out_edges()) {
1510    if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1511      ++count;
1512    }
1513  }
1514  return count;
1515}
1516
1517int TF_OperationGetControlOutputs(TF_Operation* oper,
1518                                  TF_Operation** control_outputs,
1519                                  int max_control_outputs) {
1520  int count = 0;
1521  for (const auto* edge : oper->node.out_edges()) {
1522    if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1523      if (count < max_control_outputs) {
1524        control_outputs[count] = ToOperation(edge->dst());
1525      }
1526      ++count;
1527    }
1528  }
1529  return count;
1530}
1531
1532TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1533                                            const char* attr_name,
1534                                            TF_Status* status) {
1535  TF_AttrMetadata metadata;
1536  const auto* attr = GetAttrValue(oper, attr_name, status);
1537  if (!status->status.ok()) return metadata;
1538  switch (attr->value_case()) {
1539#define SINGLE_CASE(kK, attr_type, size_expr) \
1540  case tensorflow::AttrValue::kK:             \
1541    metadata.is_list = 0;                     \
1542    metadata.list_size = -1;                  \
1543    metadata.type = attr_type;                \
1544    metadata.total_size = size_expr;          \
1545    break;
1546
1547    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1548    SINGLE_CASE(kI, TF_ATTR_INT, -1);
1549    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1550    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1551    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1552    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1553                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1554    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1555#undef SINGLE_CASE
1556
1557    case tensorflow::AttrValue::kList:
1558      metadata.is_list = 1;
1559      metadata.list_size = 0;
1560      metadata.total_size = -1;
1561#define LIST_CASE(field, attr_type, ...)              \
1562  if (attr->list().field##_size() > 0) {              \
1563    metadata.type = attr_type;                        \
1564    metadata.list_size = attr->list().field##_size(); \
1565    __VA_ARGS__;                                      \
1566    break;                                            \
1567  }
1568
1569      LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
1570                for (int i = 0; i < attr->list().s_size();
1571                     ++i) { metadata.total_size += attr->list().s(i).size(); });
1572      LIST_CASE(i, TF_ATTR_INT);
1573      LIST_CASE(f, TF_ATTR_FLOAT);
1574      LIST_CASE(b, TF_ATTR_BOOL);
1575      LIST_CASE(type, TF_ATTR_TYPE);
1576      LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1577                for (int i = 0; i < attr->list().shape_size(); ++i) {
1578                  const auto& s = attr->list().shape(i);
1579                  metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1580                });
1581      LIST_CASE(tensor, TF_ATTR_TENSOR);
1582      LIST_CASE(tensor, TF_ATTR_FUNC);
1583#undef LIST_CASE
1584      // All lists empty, determine the type from the OpDef.
1585      if (metadata.list_size == 0) {
1586        for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1587          const auto& a = oper->node.op_def().attr(i);
1588          if (a.name().compare(attr_name) != 0) continue;
1589          const string& typestr = a.type();
1590          if (typestr == "list(string)") {
1591            metadata.type = TF_ATTR_STRING;
1592          } else if (typestr == "list(int)") {
1593            metadata.type = TF_ATTR_INT;
1594          } else if (typestr == "list(float)") {
1595            metadata.type = TF_ATTR_FLOAT;
1596          } else if (typestr == "list(bool)") {
1597            metadata.type = TF_ATTR_BOOL;
1598          } else if (typestr == "list(type)") {
1599            metadata.type = TF_ATTR_TYPE;
1600          } else if (typestr == "list(shape)") {
1601            metadata.type = TF_ATTR_SHAPE;
1602          } else if (typestr == "list(tensor)") {
1603            metadata.type = TF_ATTR_TENSOR;
1604          } else if (typestr == "list(func)") {
1605            metadata.type = TF_ATTR_FUNC;
1606          } else {
1607            status->status = InvalidArgument(
1608                "Attribute '", attr_name,
1609                "' has an empty value of an unrecognized type '", typestr, "'");
1610            return metadata;
1611          }
1612        }
1613      }
1614      break;
1615
1616    case tensorflow::AttrValue::kPlaceholder:
1617      metadata.is_list = 0;
1618      metadata.list_size = -1;
1619      metadata.type = TF_ATTR_PLACEHOLDER;
1620      metadata.total_size = -1;
1621      break;
1622
1623    case tensorflow::AttrValue::kFunc:
1624      metadata.is_list = 0;
1625      metadata.list_size = -1;
1626      metadata.type = TF_ATTR_FUNC;
1627      metadata.total_size = -1;
1628      break;
1629
1630    case tensorflow::AttrValue::VALUE_NOT_SET:
1631      status->status =
1632          InvalidArgument("Attribute '", attr_name, "' has no value set");
1633      break;
1634  }
1635  return metadata;
1636}
1637
1638void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1639                               void* value, size_t max_length,
1640                               TF_Status* status) {
1641  const auto* attr = GetAttrValue(oper, attr_name, status);
1642  if (!status->status.ok()) return;
1643  if (attr->value_case() != tensorflow::AttrValue::kS) {
1644    status->status =
1645        InvalidArgument("Attribute '", attr_name, "' is not a string");
1646    return;
1647  }
1648  if (max_length <= 0) {
1649    return;
1650  }
1651  const auto& s = attr->s();
1652  std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1653}
1654
1655void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1656                                   void** values, size_t* lengths,
1657                                   int max_values, void* storage,
1658                                   size_t storage_size, TF_Status* status) {
1659  const auto* attr = GetAttrValue(oper, attr_name, status);
1660  if (!status->status.ok()) return;
1661  if (attr->value_case() != tensorflow::AttrValue::kList) {
1662    status->status =
1663        InvalidArgument("Value for '", attr_name, "' is not a list");
1664    return;
1665  }
1666  const auto len = std::min(max_values, attr->list().s_size());
1667  char* p = static_cast<char*>(storage);
1668  for (int i = 0; i < len; ++i) {
1669    const string& s = attr->list().s(i);
1670    values[i] = p;
1671    lengths[i] = s.size();
1672    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1673      status->status = InvalidArgument(
1674          "Not enough storage to hold the requested list of strings");
1675      return;
1676    }
1677    memcpy(values[i], s.data(), s.size());
1678    p += s.size();
1679  }
1680}
1681
1682#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1683  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1684            TF_Status* status) {                                             \
1685    cpp_type v;                                                              \
1686    status->status =                                                         \
1687        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1688    *value = static_cast<c_type>(v);                                         \
1689  }                                                                          \
1690  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1691                  int max_values, TF_Status* status) {                       \
1692    const auto* attr = GetAttrValue(oper, attr_name, status);                \
1693    if (!status->status.ok()) return;                                        \
1694    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1695      status->status =                                                       \
1696          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1697      return;                                                                \
1698    }                                                                        \
1699    const auto len = std::min(max_values, attr->list().list_field##_size()); \
1700    for (int i = 0; i < len; ++i) {                                          \
1701      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1702    }                                                                        \
1703  }
1704DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1705DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1706DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1707DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1708#undef DEFINE_GETATTR
1709
1710void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1711                              int64_t* value, int num_dims, TF_Status* status) {
1712  PartialTensorShape shape;
1713  status->status =
1714      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1715  if (!status->status.ok()) return;
1716  auto len = std::min(shape.dims(), num_dims);
1717  for (int i = 0; i < len; ++i) {
1718    value[i] = shape.dim_size(i);
1719  }
1720}
1721
1722void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1723                                  int64_t** values, int* num_dims,
1724                                  int max_values, int64_t* storage,
1725                                  int storage_size, TF_Status* status) {
1726  std::vector<PartialTensorShape> shapes;
1727  status->status =
1728      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1729  if (!status->status.ok()) return;
1730  auto len = std::min(static_cast<int>(shapes.size()), max_values);
1731  int64_t* p = storage;
1732  int storage_left = storage_size;
1733  for (int i = 0; i < len; ++i) {
1734    // shapes[i].dims() == -1 for shapes with an unknown rank.
1735    int64_t n = shapes[i].dims();
1736    num_dims[i] = n;
1737    values[i] = p;
1738    if (n < 0) {
1739      continue;
1740    }
1741    if (storage_left < n) {
1742      status->status = InvalidArgument(
1743          "Not enough storage to hold the requested list of shapes");
1744      return;
1745    }
1746    storage_left -= n;
1747    for (int j = 0; j < n; ++j, ++p) {
1748      *p = shapes[i].dim_size(j);
1749    }
1750  }
1751}
1752
1753void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1754                                         const char* attr_name,
1755                                         TF_Buffer* value, TF_Status* status) {
1756  const auto* attr = GetAttrValue(oper, attr_name, status);
1757  if (!status->status.ok()) return;
1758  if (attr->value_case() != tensorflow::AttrValue::kShape) {
1759    status->status =
1760        InvalidArgument("Value for '", attr_name, "' is not a shape.");
1761    return;
1762  }
1763  status->status = MessageToBuffer(attr->shape(), value);
1764}
1765
1766void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1767                                             const char* attr_name,
1768                                             TF_Buffer** values, int max_values,
1769                                             TF_Status* status) {
1770  const auto* attr = GetAttrValue(oper, attr_name, status);
1771  if (!status->status.ok()) return;
1772  if (attr->value_case() != tensorflow::AttrValue::kList) {
1773    status->status =
1774        InvalidArgument("Value for '", attr_name, "' is not a list");
1775    return;
1776  }
1777  const auto len = std::min(max_values, attr->list().shape_size());
1778  for (int i = 0; i < len; ++i) {
1779    values[i] = TF_NewBuffer();
1780    status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1781    if (!status->status.ok()) {
1782      // Delete everything allocated to far, the operation has failed.
1783      for (int j = 0; j <= i; ++j) {
1784        TF_DeleteBuffer(values[j]);
1785      }
1786      return;
1787    }
1788  }
1789}
1790
1791void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1792                               TF_Tensor** value, TF_Status* status) {
1793  *value = nullptr;
1794  Tensor t;
1795  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1796  if (!status->status.ok()) return;
1797  *value = TF_TensorFromTensor(t, status);
1798}
1799
1800void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1801                                   TF_Tensor** values, int max_values,
1802                                   TF_Status* status) {
1803  std::vector<Tensor> ts;
1804  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1805  if (!status->status.ok()) return;
1806  const auto len = std::min(max_values, static_cast<int>(ts.size()));
1807  for (int i = 0; i < len; ++i) {
1808    values[i] = TF_TensorFromTensor(ts[i], status);
1809  }
1810}
1811
1812void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1813                                   TF_Buffer* output_attr_value,
1814                                   TF_Status* status) {
1815  const auto* attr = GetAttrValue(oper, attr_name, status);
1816  if (!status->status.ok()) return;
1817  status->status = MessageToBuffer(*attr, output_attr_value);
1818}
1819
1820void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1821                           TF_Status* status) {
1822  status->status = MessageToBuffer(oper->node.def(), output_node_def);
1823}
1824
1825// TF_Graph functions ---------------------------------------------------------
1826
// Constructs an empty graph backed by the global op registry.
// The shape refiner is built from the graph's producer version and op
// registry. It therefore relies on `graph` being initialized before `refiner`.
// NOTE(review): confirm the member declaration order in TF_Graph.
// A new graph has no pending delete request and no parent. TF_NewWhile() sets
// `parent` / `parent_inputs` on its cond and body graphs.
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {}
1833
1834TF_Graph* TF_NewGraph() { return new TF_Graph; }
1835
1836void TF_DeleteGraph(TF_Graph* g) {
1837  g->mu.lock();
1838  g->delete_requested = true;
1839  const bool del = g->sessions.empty();
1840  g->mu.unlock();
1841  if (del) delete g;
1842}
1843
1844TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1845  mutex_lock l(graph->mu);
1846  auto iter = graph->name_map.find(oper_name);
1847  if (iter == graph->name_map.end()) {
1848    return nullptr;
1849  } else {
1850    return ToOperation(iter->second);
1851  }
1852}
1853
1854TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1855  if (*pos == 0) {
1856    // Advance past the first sentinel nodes in every graph (the source & sink).
1857    *pos += 2;
1858  } else {
1859    // Advance to the next node.
1860    *pos += 1;
1861  }
1862
1863  mutex_lock l(graph->mu);
1864  while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1865    Node* node = graph->graph.FindNodeId(*pos);
1866    // FindNodeId() returns nullptr for nodes that have been deleted.
1867    // We aren't currently allowing nodes to be deleted, but it is safer
1868    // to still check.
1869    if (node != nullptr) return ToOperation(node);
1870    *pos += 1;
1871  }
1872
1873  // No more nodes.
1874  return nullptr;
1875}
1876
1877void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1878                        TF_Status* status) {
1879  GraphDef def;
1880  {
1881    mutex_lock l(graph->mu);
1882    graph->graph.ToGraphDef(&def);
1883  }
1884  status->status = MessageToBuffer(def, output_graph_def);
1885}
1886
1887void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1888                      TF_Buffer* output_op_def, TF_Status* status) {
1889  const OpDef* op_def;
1890  {
1891    mutex_lock l(graph->mu);
1892    status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1893    if (!status->status.ok()) return;
1894  }
1895  status->status = MessageToBuffer(*op_def, output_op_def);
1896}
1897
1898void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1899                      TF_Status* status) {
1900  VersionDef versions;
1901  {
1902    mutex_lock l(graph->mu);
1903    versions = graph->graph.versions();
1904  }
1905  status->status = MessageToBuffer(versions, output_version_def);
1906}
1907
1908TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1909  return new TF_ImportGraphDefOptions;
1910}
1911void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1912  delete opts;
1913}
1914void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1915                                       const char* prefix) {
1916  opts->opts.prefix = prefix;
1917}
1918
1919void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1920                                              unsigned char uniquify_names) {
1921  opts->opts.uniquify_names = uniquify_names;
1922}
1923
1924void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1925                                               unsigned char uniquify_prefix) {
1926  opts->opts.uniquify_prefix = uniquify_prefix;
1927}
1928
1929void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1930                                             const char* src_name,
1931                                             int src_index, TF_Output dst) {
1932  opts->tensor_id_data.push_back(src_name);
1933  const string& src_name_str = opts->tensor_id_data.back();
1934  // We don't need to store dst's name in tensor_id_data, since `dst` must
1935  // outlive the ImportGraphDef call.
1936  opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1937}
1938
1939void TF_ImportGraphDefOptionsRemapControlDependency(
1940    TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1941  opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1942      TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1943}
1944
1945extern void TF_ImportGraphDefOptionsAddControlDependency(
1946    TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1947  opts->opts.control_dependencies.push_back(oper->node.name());
1948}
1949
1950void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1951                                             const char* oper_name, int index) {
1952  opts->tensor_id_data.push_back(oper_name);
1953  const string& oper_name_str = opts->tensor_id_data.back();
1954  opts->opts.return_tensors.emplace_back(oper_name_str, index);
1955}
1956
1957int TF_ImportGraphDefOptionsNumReturnOutputs(
1958    const TF_ImportGraphDefOptions* opts) {
1959  return opts->opts.return_tensors.size();
1960}
1961
1962void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1963                                                const char* oper_name) {
1964  opts->opts.return_nodes.push_back(oper_name);
1965}
1966
1967int TF_ImportGraphDefOptionsNumReturnOperations(
1968    const TF_ImportGraphDefOptions* opts) {
1969  return opts->opts.return_nodes.size();
1970}
1971
1972void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1973                                           int* num_outputs,
1974                                           TF_Output** outputs) {
1975  *num_outputs = results->return_tensors.size();
1976  *outputs = results->return_tensors.data();
1977}
1978
1979void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1980                                              int* num_opers,
1981                                              TF_Operation*** opers) {
1982  *num_opers = results->return_nodes.size();
1983  *opers = results->return_nodes.data();
1984}
1985
1986void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1987    TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1988    const char*** src_names, int** src_indexes) {
1989  *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1990  *src_names = results->missing_unused_key_names.data();
1991  *src_indexes = results->missing_unused_key_indexes.data();
1992}
1993
1994void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1995  delete results;
1996}
1997
1998static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1999                                      const TF_ImportGraphDefOptions* opts,
2000                                      TF_ImportGraphDefResults* tf_results,
2001                                      TF_Status* status)
2002    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2003  const int last_node_id = graph->graph.num_node_ids();
2004  tensorflow::ImportGraphDefResults results;
2005  status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2006                                              &graph->refiner, &results);
2007  if (!status->status.ok()) return;
2008
2009  // Add new nodes to name_map
2010  for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2011    auto* node = graph->graph.FindNodeId(i);
2012    if (node != nullptr) graph->name_map[node->name()] = node;
2013  }
2014
2015  // Populate return_tensors
2016  DCHECK(tf_results->return_tensors.empty());
2017  tf_results->return_tensors.resize(results.return_tensors.size());
2018  for (int i = 0; i < results.return_tensors.size(); ++i) {
2019    tf_results->return_tensors[i].oper =
2020        ToOperation(results.return_tensors[i].first);
2021    tf_results->return_tensors[i].index = results.return_tensors[i].second;
2022  }
2023
2024  // Populate return_nodes
2025  DCHECK(tf_results->return_nodes.empty());
2026  tf_results->return_nodes.resize(results.return_nodes.size());
2027  for (int i = 0; i < results.return_nodes.size(); ++i) {
2028    tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2029  }
2030
2031  // Populate missing unused map keys
2032  DCHECK(tf_results->missing_unused_key_names.empty());
2033  DCHECK(tf_results->missing_unused_key_indexes.empty());
2034  DCHECK(tf_results->missing_unused_key_names_data.empty());
2035
2036  size_t size = results.missing_unused_input_map_keys.size();
2037  tf_results->missing_unused_key_names.resize(size);
2038  tf_results->missing_unused_key_indexes.resize(size);
2039
2040  for (int i = 0; i < size; ++i) {
2041    TensorId id = results.missing_unused_input_map_keys[i];
2042    tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
2043    tf_results->missing_unused_key_names[i] =
2044        tf_results->missing_unused_key_names_data.back().c_str();
2045    tf_results->missing_unused_key_indexes[i] = id.second;
2046  }
2047}
2048
2049TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2050    TF_Graph* graph, const TF_Buffer* graph_def,
2051    const TF_ImportGraphDefOptions* options, TF_Status* status) {
2052  GraphDef def;
2053  if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2054    status->status = InvalidArgument("Invalid GraphDef");
2055    return nullptr;
2056  }
2057  auto results = new TF_ImportGraphDefResults();
2058  mutex_lock l(graph->mu);
2059  GraphImportGraphDefLocked(graph, def, options, results, status);
2060  if (!status->status.ok()) {
2061    delete results;
2062    return nullptr;
2063  }
2064  return results;
2065}
2066
2067void TF_GraphImportGraphDefWithReturnOutputs(
2068    TF_Graph* graph, const TF_Buffer* graph_def,
2069    const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2070    int num_return_outputs, TF_Status* status) {
2071  if (num_return_outputs != options->opts.return_tensors.size()) {
2072    status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2073                                     options->opts.return_tensors.size(),
2074                                     ", got ", num_return_outputs);
2075    return;
2076  }
2077  if (num_return_outputs > 0 && return_outputs == nullptr) {
2078    status->status = InvalidArgument(
2079        "'return_outputs' must be preallocated to length ", num_return_outputs);
2080    return;
2081  }
2082  GraphDef def;
2083  if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2084    status->status = InvalidArgument("Invalid GraphDef");
2085    return;
2086  }
2087  TF_ImportGraphDefResults results;
2088  mutex_lock l(graph->mu);
2089  GraphImportGraphDefLocked(graph, def, options, &results, status);
2090  DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2091  memcpy(return_outputs, results.return_tensors.data(),
2092         num_return_outputs * sizeof(TF_Output));
2093}
2094
2095void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2096                            const TF_ImportGraphDefOptions* options,
2097                            TF_Status* status) {
2098  TF_ImportGraphDefResults* results =
2099      TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2100  TF_DeleteImportGraphDefResults(results);
2101}
2102
2103// While loop functions -------------------------------------------------------
2104
2105namespace {
2106
2107#ifndef __ANDROID__
2108
2109// Creates a placeholder representing an input to the cond or body graph.
2110// TODO(skyewm): remove these from final graph
2111bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2112                 TF_Output* input, TF_Status* status) {
2113  TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2114  TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2115  // TODO(skyewm): set placeholder shape
2116  TF_Operation* oper = TF_FinishOperation(desc, status);
2117  if (!status->status.ok()) return false;
2118  *input = {oper, 0};
2119  return true;
2120}
2121
2122// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2123// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
2124// will be prepended to copied node names. `control_deps` are nodes in
2125// `dst_graph` that the copied `src_graph` nodes will have control dependencies
2126// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2127// in `dst_graph` will be returned. `return_nodes` must be non-null.
2128Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2129                 tensorflow::ShapeRefiner* dst_refiner,
2130                 const TF_Output* src_inputs,
2131                 const std::vector<tensorflow::Output>& dst_inputs,
2132                 const string& prefix,
2133                 const std::vector<tensorflow::Operation>& control_deps,
2134                 const TF_Output* nodes_to_return, int nreturn_nodes,
2135                 std::vector<tensorflow::Output>* return_nodes) {
2136  DCHECK(return_nodes != nullptr);
2137  GraphDef gdef;
2138  src_graph->ToGraphDef(&gdef);
2139
2140  tensorflow::ImportGraphDefOptions opts;
2141  opts.prefix = prefix;
2142
2143  for (int i = 0; i < dst_inputs.size(); ++i) {
2144    opts.input_map[ToTensorId(src_inputs[i])] =
2145        TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2146  }
2147  opts.skip_mapped_nodes = true;
2148
2149  for (const tensorflow::Operation& op : control_deps) {
2150    opts.control_dependencies.push_back(op.node()->name());
2151  }
2152
2153  for (int i = 0; i < nreturn_nodes; ++i) {
2154    opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2155  }
2156
2157  // TODO(skyewm): change to OutputTensor
2158  tensorflow::ImportGraphDefResults results;
2159  TF_RETURN_IF_ERROR(
2160      ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2161
2162  for (const auto& pair : results.return_tensors) {
2163    return_nodes->emplace_back(pair.first, pair.second);
2164  }
2165  return Status::OK();
2166}
2167
2168bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2169  if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2170      params.cond_graph->parent == nullptr ||
2171      params.cond_graph->parent != params.body_graph->parent ||
2172      params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2173      params.ninputs <= 0 || params.cond_inputs == nullptr ||
2174      params.body_inputs == nullptr || params.body_outputs == nullptr) {
2175    s->status = InvalidArgument(
2176        "TF_WhileParams must be created by successful TF_NewWhile() call");
2177    return false;
2178  }
2179  return true;
2180}
2181
2182bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2183  if (params.cond_output.oper == nullptr) {
2184    s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2185    return false;
2186  }
2187  for (int i = 0; i < params.ninputs; ++i) {
2188    if (params.body_outputs[i].oper == nullptr) {
2189      s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2190                                  "field isn't set");
2191      return false;
2192    }
2193  }
2194  if (params.name == nullptr) {
2195    s->status = InvalidArgument("TF_WhileParams `name` field is null");
2196    return false;
2197  }
2198  return true;
2199}
2200
2201#endif  // __ANDROID__
2202
2203void FreeWhileResources(const TF_WhileParams* params) {
2204  TF_DeleteGraph(params->cond_graph);
2205  TF_DeleteGraph(params->body_graph);
2206  delete[] params->cond_inputs;
2207  delete[] params->body_inputs;
2208  delete[] params->body_outputs;
2209}
2210
2211TF_WhileParams EmptyWhileParams() {
2212  return {0,       nullptr, nullptr, {nullptr, 0},
2213          nullptr, nullptr, nullptr, nullptr};
2214}
2215
2216}  // namespace
2217
2218TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2219                           TF_Status* status) {
2220#ifdef __ANDROID__
2221  status->status = tensorflow::errors::Unimplemented(
2222      "Creating while loops is not supported in Android. File a bug at "
2223      "https://github.com/tensorflow/tensorflow/issues if this feature is "
2224      "important to you");
2225  return EmptyWhileParams();
2226#else
2227  if (ninputs == 0) {
2228    status->status =
2229        InvalidArgument("TF_NewWhile() must be passed at least one input");
2230    return EmptyWhileParams();
2231  }
2232
2233  TF_Graph* cond_graph = TF_NewGraph();
2234  TF_Graph* body_graph = TF_NewGraph();
2235  cond_graph->parent = g;
2236  cond_graph->parent_inputs = inputs;
2237  body_graph->parent = g;
2238  body_graph->parent_inputs = inputs;
2239
2240  TF_Output* cond_inputs = new TF_Output[ninputs];
2241  TF_Output cond_output = {nullptr, -1};
2242  TF_Output* body_inputs = new TF_Output[ninputs];
2243  TF_Output* body_outputs = new TF_Output[ninputs];
2244  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2245  const char* name = nullptr;
2246
2247  for (int i = 0; i < ninputs; ++i) {
2248    // TODO(skyewm): prefix names with underscore (requires some plumbing)
2249    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2250                     &cond_inputs[i], status)) {
2251      break;
2252    }
2253    if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2254                     &body_inputs[i], status)) {
2255      break;
2256    }
2257  }
2258
2259  TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
2260                           body_graph, body_inputs, body_outputs, name};
2261
2262  if (!status->status.ok()) {
2263    FreeWhileResources(&params);
2264    return EmptyWhileParams();
2265  }
2266  return params;
2267#endif  // __ANDROID__
2268}
2269
2270#ifndef __ANDROID__
namespace {

// Builds the while loop described by `params` into the parent graph recorded
// in `params->cond_graph->parent`, holding that graph's lock throughout, and
// writes the loop outputs to `outputs`. NOTE(review): `outputs` presumably
// must be caller-allocated with room for `params->ninputs` entries; confirm
// against the public header. Errors are reported through `status`.
// TODO(skyewm): make nodes in while loop unfetchable like in Python version
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  // Check that the caller filled in cond_output, body_outputs and name.
  if (!ValidateInputWhileParams(*params, status)) return;

  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  // Its placeholders (`cond_inputs`) are remapped to the loop's `inputs`, and
  // the copy of `cond_output` is returned through `*output`.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return Status::OK();
      };

  // 'body_fn' copies the body graph into the parent graph.
  // Its placeholders (`body_inputs`) are remapped to the loop's `inputs`, and
  // the copies of the `num_loop_vars` body outputs are appended to `outputs`.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return Status::OK();
      };

  // Create the while loop using an internal scope.
  // The scope records construction errors into `status->status`. The return
  // value of BuildWhileLoop() below then overwrites it.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  // Node ids at or above this value are created by BuildWhileLoop().
  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  // Only as many entries as BuildWhileLoop() produced are written.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}

}  // namespace
2345#endif  // __ANDROID__
2346
2347void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2348                    TF_Output* outputs) {
2349#ifdef __ANDROID__
2350  status->status = tensorflow::errors::Unimplemented(
2351      "Creating while loops is not supported in Android. File a bug at "
2352      "https://github.com/tensorflow/tensorflow/issues if this feature is "
2353      "important to you");
2354#else
2355  // If it appears the caller created or modified `params`, don't free resources
2356  if (!ValidateConstWhileParams(*params, status)) return;
2357  TF_FinishWhileHelper(params, status, outputs);
2358  FreeWhileResources(params);
2359#endif  // __ANDROID__
2360}
2361
2362void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2363
2364void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2365                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
2366#ifdef __ANDROID__
2367  status->status = tensorflow::errors::Unimplemented(
2368      "Adding gradients is not supported in Android. File a bug at "
2369      "https://github.com/tensorflow/tensorflow/issues if this feature is "
2370      "important to you");
2371#else
2372  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2373  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2374  std::vector<tensorflow::Output> dy_arg;
2375
2376  {
2377    // We need to hold on to the lock while we have a scope that uses TF_Graph.
2378    mutex_lock graph_lock(g->mu);
2379
2380    const int first_new_node_id = g->graph.num_node_ids();
2381
2382    tensorflow::Scope scope =
2383        NewInternalScope(&g->graph, &status->status, &g->refiner)
2384            .NewSubScope("gradients");
2385
2386    if (dx != nullptr) {
2387      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2388      status->status =
2389          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2390    } else {
2391      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2392    }
2393
2394    // Update g->name_map with the name_map from the scope, which will contain
2395    // the new gradient ops.
2396    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2397      Node* n = g->graph.FindNodeId(i);
2398      if (n == nullptr) continue;
2399      g->name_map[n->name()] = n;
2400    }
2401  }
2402
2403  // Unpack the results from grad_outputs_arg.
2404  TFOutputsFromOutputs(dy_arg, dy);
2405#endif  // __ANDROID__
2406}
2407
2408// TF_Session functions ----------------------------------------------
2409
2410TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2411    : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
2412  if (s->LocalDeviceManager(&device_mgr).ok()) {
2413    devices = device_mgr->ListDevices();
2414  }
2415}
2416
2417TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2418                          TF_Status* status) {
2419  Session* session;
2420  status->status = NewSession(opt->options, &session);
2421  if (status->status.ok()) {
2422    TF_Session* new_session = new TF_Session(session, graph);
2423    if (graph != nullptr) {
2424      mutex_lock l(graph->mu);
2425      graph->sessions[new_session] = Status::OK();
2426    }
2427    return new_session;
2428  } else {
2429    DCHECK_EQ(nullptr, session);
2430    return nullptr;
2431  }
2432}
2433
2434TF_Session* TF_LoadSessionFromSavedModel(
2435    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2436    const char* export_dir, const char* const* tags, int tags_len,
2437    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2438// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
2439// the tensorflow/cc/saved_model:loader build target is Android friendly.
2440#ifdef __ANDROID__
2441  status->status = tensorflow::errors::Unimplemented(
2442      "Loading a SavedModel is not supported in Android. File a bug at "
2443      "https://github.com/tensorflow/tensorflow/issues if this feature is "
2444      "important to you");
2445  return nullptr;
2446#else
2447  mutex_lock l(graph->mu);
2448  if (!graph->name_map.empty()) {
2449    status->status = InvalidArgument("Graph is non-empty.");
2450    return nullptr;
2451  }
2452
2453  RunOptions run_options_proto;
2454  if (run_options != nullptr && !run_options_proto.ParseFromArray(
2455                                    run_options->data, run_options->length)) {
2456    status->status = InvalidArgument("Unparseable RunOptions proto");
2457    return nullptr;
2458  }
2459
2460  std::unordered_set<string> tag_set;
2461  for (int i = 0; i < tags_len; i++) {
2462    tag_set.insert(string(tags[i]));
2463  }
2464
2465  tensorflow::SavedModelBundle bundle;
2466  status->status =
2467      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2468                                 export_dir, tag_set, &bundle);
2469  if (!status->status.ok()) return nullptr;
2470
2471  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2472  // extends using GraphDefs. The Graph instance is different, but equivalent
2473  // to the one used to create the session.
2474  //
2475  // TODO(jhseu): When Session is modified to take Graphs instead of
2476  // GraphDefs, return the Graph generated in LoadSavedModel().
2477  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2478  TF_ImportGraphDefResults results;
2479  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2480                            import_opts, &results, status);
2481  TF_DeleteImportGraphDefOptions(import_opts);
2482  if (TF_GetCode(status) != TF_OK) return nullptr;
2483
2484  if (meta_graph_def != nullptr) {
2485    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2486    if (!status->status.ok()) return nullptr;
2487  }
2488
2489  TF_Session* session = new TF_Session(bundle.session.release(), graph);
2490
2491  graph->sessions[session] = Status::OK();
2492  session->last_num_graph_nodes = graph->graph.num_node_ids();
2493  return session;
2494#endif  // __ANDROID__
2495}
2496
2497void TF_CloseSession(TF_Session* s, TF_Status* status) {
2498  status->status = s->session->Close();
2499}
2500
2501void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2502  status->status = Status::OK();
2503  TF_Graph* const graph = s->graph;
2504  if (graph != nullptr) {
2505    graph->mu.lock();
2506    graph->sessions.erase(s);
2507    const bool del = graph->delete_requested && graph->sessions.empty();
2508    graph->mu.unlock();
2509    if (del) delete graph;
2510  }
2511  delete s->session;
2512  delete s;
2513}
2514
2515// TODO(josh11b,mrry): Change Session to be able to use a Graph*
2516// directly, instead of requiring us to serialize to a GraphDef and
2517// call Session::Extend().
2518static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
2519  if (session->graph != nullptr) {
2520    mutex_lock session_lock(session->mu);
2521    session->graph->mu.lock();
2522    const Graph& graph = session->graph->graph;
2523
2524    status->status = session->graph->sessions[session];
2525    if (!status->status.ok()) {
2526      session->graph->mu.unlock();
2527      return false;
2528    }
2529
2530    const auto num_nodes = graph.num_node_ids();
2531    if (session->last_num_graph_nodes < num_nodes) {
2532      status->status = tensorflow::ValidateNoCycles(session->graph->graph);
2533      if (!status->status.ok()) {
2534        session->graph->mu.unlock();
2535        return false;
2536      }
2537
2538      GraphDef graph_def;
2539      *graph_def.mutable_versions() = graph.versions();
2540      // Fill graph_def with nodes with ids in the range
2541      // [session->last_num_graph_nodes, num_nodes), that is the nodes
2542      // added since the last TF_SessionRun() call.
2543      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
2544        Node* const node = graph.FindNodeId(id);
2545        if (node != nullptr && node->IsOp()) {
2546          NodeDef* const node_def = graph_def.add_node();
2547          *node_def = node->def();
2548        }
2549      }
2550      *graph_def.mutable_library() = graph.flib_def().ToProto();
2551      session->graph->mu.unlock();
2552      status->status = session->session->Extend(graph_def);
2553      if (!status->status.ok()) {
2554        // Contract is we always delete input_values[i].
2555        return false;
2556      }
2557      // Note: session->session is not modified if Extend() fails, so
2558      // we only set last_num_graph_nodes if it succeeds.
2559      session->last_num_graph_nodes = num_nodes;
2560    } else {
2561      session->graph->mu.unlock();
2562    }
2563  }
2564  return true;
2565}
2566
2567void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2568                   const TF_Output* inputs, TF_Tensor* const* input_values,
2569                   int ninputs, const TF_Output* outputs,
2570                   TF_Tensor** output_values, int noutputs,
2571                   const TF_Operation* const* target_opers, int ntargets,
2572                   TF_Buffer* run_metadata, TF_Status* status) {
2573  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2574  // directly, instead of requiring us to serialize to a GraphDef and
2575  // call Session::Extend().
2576  if (!ExtendSessionGraphHelper(session, status)) {
2577    return;
2578  }
2579
2580  TF_Run_Setup(noutputs, output_values, status);
2581
2582  // Convert from TF_Output and TF_Tensor to a string and Tensor.
2583  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2584  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2585  for (int i = 0; i < ninputs; ++i) {
2586    input_pairs[i].first = OutputName(inputs[i]);
2587  }
2588
2589  // Convert from TF_Output to string names.
2590  std::vector<string> output_names(noutputs);
2591  for (int i = 0; i < noutputs; ++i) {
2592    output_names[i] = OutputName(outputs[i]);
2593  }
2594
2595  // Convert from TF_Operation* to string names.
2596  std::vector<string> target_names(ntargets);
2597  for (int i = 0; i < ntargets; ++i) {
2598    target_names[i] = target_opers[i]->node.name();
2599  }
2600
2601  // Actually run.
2602  TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2603                output_names, output_values, target_names, run_metadata,
2604                status);
2605}
2606
2607void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2608                         int ninputs, const TF_Output* outputs, int noutputs,
2609                         const TF_Operation* const* target_opers, int ntargets,
2610                         const char** handle, TF_Status* status) {
2611  *handle = nullptr;
2612
2613  if (!ExtendSessionGraphHelper(session, status)) {
2614    return;
2615  }
2616
2617  std::vector<string> input_names(ninputs);
2618  for (int i = 0; i < ninputs; ++i) {
2619    input_names[i] = OutputName(inputs[i]);
2620  }
2621
2622  std::vector<string> output_names(noutputs);
2623  for (int i = 0; i < noutputs; ++i) {
2624    output_names[i] = OutputName(outputs[i]);
2625  }
2626
2627  std::vector<string> target_names(ntargets);
2628  for (int i = 0; i < ntargets; ++i) {
2629    target_names[i] = target_opers[i]->node.name();
2630  }
2631
2632  string new_handle;
2633  status->status = session->session->PRunSetup(input_names, output_names,
2634                                               target_names, &new_handle);
2635  if (status->status.ok()) {
2636    char* buf = new char[new_handle.size() + 1];
2637    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2638    *handle = buf;
2639  }
2640}
2641
// Frees a handle returned by TF_SessionPRunSetup, which allocates it with
// new[]. Deleting nullptr is a no-op.
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2646
2647void TF_SessionPRun(TF_Session* session, const char* handle,
2648                    const TF_Output* inputs, TF_Tensor* const* input_values,
2649                    int ninputs, const TF_Output* outputs,
2650                    TF_Tensor** output_values, int noutputs,
2651                    const TF_Operation* const* target_opers, int ntargets,
2652                    TF_Status* status) {
2653  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2654  // directly, instead of requiring us to serialize to a GraphDef and
2655  // call Session::Extend().
2656  if (!ExtendSessionGraphHelper(session, status)) {
2657    return;
2658  }
2659
2660  TF_Run_Setup(noutputs, output_values, status);
2661
2662  // Convert from TF_Output and TF_Tensor to a string and Tensor.
2663  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2664  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2665  for (int i = 0; i < ninputs; ++i) {
2666    input_pairs[i].first = OutputName(inputs[i]);
2667  }
2668
2669  // Convert from TF_Output to string names.
2670  std::vector<string> output_names(noutputs);
2671  for (int i = 0; i < noutputs; ++i) {
2672    output_names[i] = OutputName(outputs[i]);
2673  }
2674
2675  // Convert from TF_Operation* to string names.
2676  std::vector<string> target_names(ntargets);
2677  for (int i = 0; i < ntargets; ++i) {
2678    target_names[i] = target_opers[i]->node.name();
2679  }
2680
2681  TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2682                output_values, target_names, nullptr, status);
2683}
2684
2685TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2686  tensorflow::OpList op_list;
2687  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2688    status->status = InvalidArgument("Unparseable OpList");
2689    return nullptr;
2690  }
2691  status->status = Status::OK();
2692  return new TF_ApiDefMap(op_list);
2693}
2694
2695void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2696
2697void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2698                     size_t text_len, TF_Status* status) {
2699#ifdef __ANDROID__
2700  status->status = tensorflow::errors::Unimplemented(
2701      "ApiDefMap is not supported in Android.");
2702#else
2703  mutex_lock l(api_def_map->lock);
2704  if (api_def_map->update_docs_called) {
2705    status->status = FailedPrecondition(
2706        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2707        "called.");
2708    return;
2709  }
2710  string api_def_text(text, text_len);
2711  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2712#endif  // __ANDROID__
2713}
2714
2715TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2716                           size_t name_len, TF_Status* status) {
2717#ifdef __ANDROID__
2718  status->status = tensorflow::errors::Unimplemented(
2719      "ApiDefMap is not supported in Android.");
2720  return nullptr;
2721#else
2722  mutex_lock l(api_def_map->lock);
2723  if (!api_def_map->update_docs_called) {
2724    api_def_map->api_def_map.UpdateDocs();
2725    api_def_map->update_docs_called = true;
2726  }
2727  string name_str(name, name_len);
2728  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2729
2730  TF_Buffer* ret = TF_NewBuffer();
2731  status->status = MessageToBuffer(*api_def, ret);
2732  return ret;
2733#endif  // __ANDROID__
2734}
2735}  // end extern "C"
2736