1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16788f359b7218ad46696c15459c89688ffe70955eA. Unique TensorFlower#include "tensorflow/c/c_api.h" 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1834501544c48061bdfb6a0e7cddf1f70136cf3040Asim Shankar#include <algorithm> 19786938758a39940ae0154834bfed9e21894afa28Asim Shankar#include <limits> 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory> 21b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include <vector> 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 23e8f2aad0c0502fde74fc629f5b13f04d5d206700Asim Shankar#ifndef __ANDROID__ 24908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#include "tensorflow/cc/framework/gradients.h" 25908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#include "tensorflow/cc/framework/ops.h" 26908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#include "tensorflow/cc/framework/scope_internal.h" 270fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#include "tensorflow/cc/ops/while_loop.h" 2835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu#include "tensorflow/cc/saved_model/loader.h" 29d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#include "tensorflow/core/framework/op_gen_lib.h" 30e8f2aad0c0502fde74fc629f5b13f04d5d206700Asim Shankar#endif 3179d6be3b86ba32048cac6afd5d1c5d1bb8aee6d9Alexandre Passos#include "tensorflow/c/c_api_internal.h" 3222651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar#include "tensorflow/core/common_runtime/device_mgr.h" 3387944a798a13832f4b88a716d49a73ac37ccb9a3Vijay Vasudevan#include "tensorflow/core/common_runtime/shape_refiner.h" 34e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/allocation_description.pb.h" 35ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower#include "tensorflow/core/framework/log_memory.h" 3695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#include "tensorflow/core/framework/node_def_util.h" 37ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h" 38f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower#include "tensorflow/core/framework/partial_tensor_shape.h" 39b2f0bc2e230dcd690e7cf34e5425f0f499d9557bJosh Levenberg#include "tensorflow/core/framework/tensor.h" 40b2f0bc2e230dcd690e7cf34e5425f0f499d9557bJosh Levenberg#include "tensorflow/core/framework/tensor_shape.h" 41e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor_shape.pb.h" 420b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar#include "tensorflow/core/framework/types.h" 43e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/versions.pb.h" 44f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower#include "tensorflow/core/graph/graph.h" 4522221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar#include "tensorflow/core/graph/graph_constructor.h" 46f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower#include "tensorflow/core/graph/node_builder.h" 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/coding.h" 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/errors.h" 49b2f0bc2e230dcd690e7cf34e5425f0f499d9557bJosh Levenberg#include "tensorflow/core/lib/core/status.h" 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/stringpiece.h" 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h" 52f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower#include "tensorflow/core/lib/strings/strcat.h" 5383c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu#include "tensorflow/core/platform/mem.h" 54f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower#include "tensorflow/core/platform/mutex.h" 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/protobuf.h" 562677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan#include "tensorflow/core/platform/thread_annotations.h" 57b2f0bc2e230dcd690e7cf34e5425f0f499d9557bJosh Levenberg#include "tensorflow/core/platform/types.h" 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/session.h" 59a374ea13c0c7b9598b5ada851b43655f895a578eAsim Shankar#include "tensorflow/core/public/version.h" 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// The implementation below is at the top level instead of the 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// brain namespace because we are defining 'extern "C"' functions. 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::AllocationDescription; 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::DataType; 65f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::Graph; 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::GraphDef; 67caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankarusing tensorflow::mutex_lock; 68078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlowerusing tensorflow::NameRangeMap; 69078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlowerusing tensorflow::NameRangesForNode; 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::NewSession; 71f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::Node; 72f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::NodeBuilder; 734c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarusing tensorflow::NodeDef; 74078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlowerusing tensorflow::OpDef; 75f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::OpRegistry; 76f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::PartialTensorShape; 77b0fc9af3818e3d5385edbfa9e38d66c3000ecf71A. Unique TensorFlowerusing tensorflow::RunMetadata; 78b0fc9af3818e3d5385edbfa9e38d66c3000ecf71A. Unique TensorFlowerusing tensorflow::RunOptions; 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::Session; 80b0fc9af3818e3d5385edbfa9e38d66c3000ecf71A. Unique TensorFlowerusing tensorflow::Status; 81caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankarusing tensorflow::string; 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::Tensor; 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::TensorBuffer; 84bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milneusing tensorflow::TensorId; 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurusing tensorflow::TensorShape; 86f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerusing tensorflow::TensorShapeProto; 877fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milneusing tensorflow::VersionDef; 884c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarusing tensorflow::error::Code; 8922651083406ca01ac9d481e3367a3510d25f88cdAsim Shankarusing tensorflow::errors::FailedPrecondition; 904c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarusing tensorflow::errors::InvalidArgument; 914c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarusing tensorflow::gtl::ArraySlice; 924c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarusing tensorflow::strings::StrCat; 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurextern "C" { 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// -------------------------------------------------------------------------- 97a374ea13c0c7b9598b5ada851b43655f895a578eAsim Shankarconst char* TF_Version() { return TF_VERSION_STRING; } 98a374ea13c0c7b9598b5ada851b43655f895a578eAsim Shankar 99a374ea13c0c7b9598b5ada851b43655f895a578eAsim Shankar// -------------------------------------------------------------------------- 100bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseusize_t TF_DataTypeSize(TF_DataType dt) { 101bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseu return static_cast<size_t>( 102bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseu tensorflow::DataTypeSize(static_cast<DataType>(dt))); 103bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseu} 104bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseu 105bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09Jonathan Hseu// -------------------------------------------------------------------------- 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_Status* TF_NewStatus() { return new TF_Status; } 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid TF_DeleteStatus(TF_Status* s) { delete s; } 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { 112ee5a6641e3a2db7807a655bef21614e97b82c769Akshay Modi if (code == TF_OK) { 113ee5a6641e3a2db7807a655bef21614e97b82c769Akshay Modi s->status = Status::OK(); 114ee5a6641e3a2db7807a655bef21614e97b82c769Akshay Modi return; 115ee5a6641e3a2db7807a655bef21614e97b82c769Akshay Modi } 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg)); 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_Code TF_GetCode(const TF_Status* s) { 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return static_cast<TF_Code>(s->status.code()); 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurconst char* TF_Message(const TF_Status* s) { 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return s->status.error_message().c_str(); 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// -------------------------------------------------------------------------- 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace { 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass TF_ManagedBuffer : public TensorBuffer { 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void* data_; 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur size_t len_; 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void (*deallocator_)(void* data, size_t len, void* arg); 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void* deallocator_arg_; 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~TF_ManagedBuffer() override { 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur (*deallocator_)(data_, len_, deallocator_arg_); 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void* data() const override { return data_; } 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur size_t size() const override { return len_; } 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorBuffer* root_buffer() override { return this; } 144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void FillAllocationDescription(AllocationDescription* proto) const override { 145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensorflow::int64 rb = size(); 146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur proto->set_requested_bytes(rb); 147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 149d9b9fa4ffe85be28ebd8a3491e967b56f7f1bb87Alexandre Passos 150d9b9fa4ffe85be28ebd8a3491e967b56f7f1bb87Alexandre Passos // Prevents input forwarding from mutating this buffer. 151d9b9fa4ffe85be28ebd8a3491e967b56f7f1bb87Alexandre Passos bool OwnsMemory() const override { return false; } 152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 154ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseuvoid* allocate_tensor(const char* operation, size_t len) { 155ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu void* data = 156ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); 15728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { 158ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu tensorflow::LogMemory::RecordRawAllocation( 159ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, 160ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu len, data, tensorflow::cpu_allocator()); 161ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu } 162ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu return data; 163ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu} 164ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu 165ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseuvoid deallocate_buffer(void* data, size_t len, void* arg) { 16628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { 167ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower tensorflow::LogMemory::RecordRawDeallocation( 168ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower "TensorFlow C Api", 169ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, 170ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower tensorflow::cpu_allocator(), false); 171ec1403e7dc2b919531e527d36d28659f60621c9eA. Unique TensorFlower } 172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensorflow::cpu_allocator()->DeallocateRaw(data); 173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 17495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace 176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1772bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim ShankarTF_Tensor::~TF_Tensor() { buffer->Unref(); } 17822651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar 179ae7b1310c5b2bbb333191d0def7985202dee382aJonathan HseuTF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, 180ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu int num_dims, size_t len) { 181ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu void* data = allocate_tensor("TF_AllocateTensor", len); 182ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer, 183ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu nullptr); 184ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu} 185ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu 186f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerTF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, 187f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower void* data, size_t len, 188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void (*deallocator)(void* data, size_t len, void* arg), 189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void* deallocator_arg) { 190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<tensorflow::int64> dimvec(num_dims); 1917cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < num_dims; ++i) { 192f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower dimvec[i] = static_cast<tensorflow::int64>(dims[i]); 193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TF_ManagedBuffer* buf = new TF_ManagedBuffer; 196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur buf->len_ = len; 1970b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (dtype != TF_STRING && dtype != TF_RESOURCE && 1980b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) && 1990b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) { 2000b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar // TF_STRING and TF_RESOURCE tensors have a different representation in 2010b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste 202caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar // (any alignment requirements will be taken care of by TF_TensorToTensor 2030b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar // and TF_TensorFromTensor). 2040b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar // 205caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar // Other types have the same representation, so copy only if it is safe to 206caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar // do so. 207ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu buf->data_ = allocate_tensor("TF_NewTensor", len); 208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::memcpy(buf->data_, data, len); 209ae7b1310c5b2bbb333191d0def7985202dee382aJonathan Hseu buf->deallocator_ = deallocate_buffer; 210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur buf->deallocator_arg_ = nullptr; 211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Free the original buffer. 212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur deallocator(data, len, deallocator_arg); 213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } else { 214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur buf->data_ = data; 215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur buf->deallocator_ = deallocator; 216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur buf->deallocator_arg_ = deallocator_arg; 217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 218caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; 219caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar size_t elem_size = TF_DataTypeSize(dtype); 220caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { 221caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar delete ret; 222caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar return nullptr; 223caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar } 224caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar return ret; 225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 227a5a8558feb9417359e30a991ab5e01cf17194473Alexandre PassosTF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { 228a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos // It is safe to move the Tensor if and only if we own the unique reference to 229a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos // it. In that case, we might as well not delete and reallocate, but a future 230a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos // implementation might need to do so. 2312bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar TensorBuffer* buf = tensor->buffer; 23222651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() && 23322651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar buf->OwnsMemory()) { 234a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos return tensor; 235a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos } 236a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos return nullptr; 237a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos} 238a5a8558feb9417359e30a991ab5e01cf17194473Alexandre Passos 23922651083406ca01ac9d481e3367a3510d25f88cdAsim Shankarvoid TF_DeleteTensor(TF_Tensor* t) { delete t; } 240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 241f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; } 242f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurint TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); } 243f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerint64_t TF_Dim(const TF_Tensor* t, int dim_index) { 244f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return static_cast<int64_t>(t->shape.dim_size(dim_index)); 245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 2462bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankarsize_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } 2472bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankarvoid* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } 248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// -------------------------------------------------------------------------- 250cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankarsize_t TF_StringEncode(const char* src, size_t src_len, char* dst, 251cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar size_t dst_len, TF_Status* status) { 252cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar const size_t sz = TF_StringEncodedSize(src_len); 253fd398fdb1a68a823c56836184197fe3aaaf09a5aAsim Shankar if (sz < src_len) { 254fd398fdb1a68a823c56836184197fe3aaaf09a5aAsim Shankar status->status = InvalidArgument("src string is too large to encode"); 255fd398fdb1a68a823c56836184197fe3aaaf09a5aAsim Shankar return 0; 256fd398fdb1a68a823c56836184197fe3aaaf09a5aAsim Shankar } 257cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar if (dst_len < sz) { 258cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar status->status = 259cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar InvalidArgument("dst_len (", dst_len, ") too small to encode a ", 260cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar src_len, "-byte string"); 261cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar return 0; 262cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar } 263cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar dst = tensorflow::core::EncodeVarint64(dst, src_len); 264cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar memcpy(dst, src, src_len); 265cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar return sz; 266cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar} 267cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar 2684c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarstatic Status TF_StringDecode_Impl(const char* src, size_t src_len, 2694c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar const char** dst, size_t* dst_len) { 270cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar tensorflow::uint64 len64 = 0; 271cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64); 272cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar if (p == nullptr) { 2734c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return InvalidArgument("invalid string encoding or truncated src buffer"); 274cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar } 275cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar if (len64 > std::numeric_limits<size_t>::max()) { 2764c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return InvalidArgument("encoded string is ", len64, 2774c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar "-bytes, which is too large for this architecture"); 278cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar } 279cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar *dst = p; 280cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar *dst_len = static_cast<size_t>(len64); 2814c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return Status::OK(); 2824c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar} 2834c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar 2844c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankarsize_t TF_StringDecode(const char* src, size_t src_len, const char** dst, 2854c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar size_t* dst_len, TF_Status* status) { 2864c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); 2874c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (!status->status.ok()) return 0; 2884c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return static_cast<size_t>(*dst - src) + *dst_len; 289cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar} 290cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar 291cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankarsize_t TF_StringEncodedSize(size_t len) { 292cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len; 293cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar} 294cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar 295cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar// -------------------------------------------------------------------------- 296f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } 297f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } 298f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 299f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid TF_SetTarget(TF_SessionOptions* options, const char* target) { 300f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur options->options.target = target; 301f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 302f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 3034213ac97be449d0e40631a314d2b7bd3901d4967Vijay Vasudevanvoid TF_SetConfig(TF_SessionOptions* options, const void* proto, 3044213ac97be449d0e40631a314d2b7bd3901d4967Vijay Vasudevan size_t proto_len, TF_Status* status) { 3054213ac97be449d0e40631a314d2b7bd3901d4967Vijay Vasudevan if (!options->options.config.ParseFromArray(proto, proto_len)) { 306c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Unparseable ConfigProto"); 307f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 308f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 309b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang// -------------------------------------------------------------------------- 310a465e286ccc9c83cb44f3153e91a8ae937904c49Zongheng YangTF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } 311b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 312b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng YangTF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { 31383c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu void* copy = tensorflow::port::Malloc(proto_len); 314b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang memcpy(copy, proto, proto_len); 315b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 316b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang TF_Buffer* buf = new TF_Buffer; 317b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang buf->data = copy; 318b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang buf->length = proto_len; 31983c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu buf->data_deallocator = [](void* data, size_t length) { 32083c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu tensorflow::port::Free(data); 32183c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu }; 322b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang return buf; 323b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang} 324b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 325b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yangvoid TF_DeleteBuffer(TF_Buffer* buffer) { 326b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang if (buffer->data_deallocator != nullptr) { 327b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang (*buffer->data_deallocator)(const_cast<void*>(buffer->data), 328b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang buffer->length); 329b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang } 330b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang delete buffer; 331b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang} 332b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 333b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng YangTF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } 334f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 335f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// -------------------------------------------------------------------------- 336f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 337fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim ShankarTF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, 338fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankar TF_Status* status) { 339f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Session* session; 340f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = NewSession(opt->options, &session); 341f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (status->status.ok()) { 342fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankar return new TF_DeprecatedSession({session}); 343f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } else { 344f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DCHECK_EQ(nullptr, session); 345d83074847ebfe8871188f1f9f1e84ab0451f59e6A. Unique TensorFlower return nullptr; 346f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 347f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 348f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 349fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { 350f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = s->session->Close(); 351f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 352f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 353fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { 354f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = Status::OK(); 355f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur delete s->session; 356f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur delete s; 357f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 358f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 359fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, 360fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankar size_t proto_len, TF_Status* status) { 361f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur GraphDef g; 362f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { 363c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Invalid GraphDef"); 364f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return; 365f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 366f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = s->session->Extend(g); 367f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 368f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 369f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurstatic void DeleteArray(void* data, size_t size, void* arg) { 370f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DCHECK_EQ(data, arg); 371f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur delete[] reinterpret_cast<char*>(arg); 372f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 373f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 374f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // end extern "C" 375f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 376f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 3777083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moorenamespace { 3787083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 3797083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore// Reset helper for converting character arrays to string vectors. 3807083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moorevoid TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, 3817083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore int ncontainers, TF_Status* status) { 382ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> container_names(ncontainers); 3837cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ncontainers; ++i) { 3847083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore container_names[i] = containers[i]; 3857083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore } 3867083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 3877083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore status->status = Reset(opt->options, container_names); 3887083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore} 3897083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 390f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist// This traverses the specified nodes in topological order to verify there are 391f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist// no cycles. Starting with inputless nodes, it visits nodes whose inputs have 392f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist// all been visited, and counts the total number of visited nodes. If there is a 393f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist// cycle, nodes in the cycle will never be visited, and the visited count will 394f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist// be less than the total node count. 395f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia NordquistStatus ValidateNoCycles(const Graph& g) { 396f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist // TODO(nolivia): check this on a subset of the graph instead of all of it. 397f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist // A node is ready when all of its inputs have been visited. 398f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist std::vector<const Node*> ready; 399a80fd2acf08ceba0c8fc7684c3013e8e7d6bd8d3Skye Wanderman-Milne std::vector<int> pending_count(g.num_node_ids(), 0); 400f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist 401a80fd2acf08ceba0c8fc7684c3013e8e7d6bd8d3Skye Wanderman-Milne for (int i = 0; i < g.num_node_ids(); ++i) { 402f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist const Node* n = g.FindNodeId(i); 403f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (n == nullptr) continue; 404f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist pending_count[i] = n->in_edges().size(); 405f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (n->IsMerge()) { 406f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist // While-loop cycles are legal cycles so we manually adjust the 407f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist // pending_count to make sure that the loop is visited. 408f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist for (const Edge* e : n->in_edges()) { 409f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (!e->IsControlEdge() && e->src()->IsNextIteration()) { 410f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist pending_count[i]--; 411f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 412f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 413f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 414f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (pending_count[i] == 0) { 415f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist ready.push_back(n); 416f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 417f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 418f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist 419f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist int processed = 0; 420f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist while (!ready.empty()) { 421f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist const Node* node = ready.back(); 422f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist ready.pop_back(); 423f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist ++processed; 424f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist 425f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist for (const Edge* out : node->out_edges()) { 426f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist const int output_id = out->dst()->id(); 427f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist pending_count[output_id]--; 428f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (pending_count[output_id] == 0) { 429f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist ready.push_back(out->dst()); 430f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 431f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 432f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 433f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist 434a80fd2acf08ceba0c8fc7684c3013e8e7d6bd8d3Skye Wanderman-Milne if (processed < g.num_nodes()) { 435f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist std::vector<string> nodes_in_cycle; 436f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; 437f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist ++i) { 438f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (pending_count[i] != 0) { 439f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist nodes_in_cycle.push_back(g.FindNodeId(i)->name()); 440f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 441f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 442f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist return errors::InvalidArgument( 443a80fd2acf08ceba0c8fc7684c3013e8e7d6bd8d3Skye Wanderman-Milne "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, 444f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); 445f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 446f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist return Status::OK(); 447f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist} 4487083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore} // namespace 4497083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore} // namespace tensorflow 4507083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 4517083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Mooreextern "C" { 4527083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 4537083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moorevoid TF_Reset(const TF_SessionOptions* opt, const char** containers, 4547083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore int ncontainers, TF_Status* status) { 4557083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status); 4567083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore} 4577083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 4587083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore} // end extern "C" 4597083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 4607083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moorenamespace tensorflow { 4617083e2af69f5692d74fc6b1148c0ecea0c562274Sherry Moore 4624c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim ShankarStatus TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { 4630b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (src->dtype == TF_RESOURCE) { 4640b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (src->shape.dims() != 0) { 4650b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar return InvalidArgument( 4660b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with " 4670b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar "shape ", 4680b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar src->shape.DebugString()); 4690b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar } 4700b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar *dst = Tensor(DT_RESOURCE, src->shape); 4710b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (!dst->scalar<ResourceHandle>()().ParseFromString( 4720b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar string(static_cast<const char*>(TF_TensorData(src)), 4730b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar TF_TensorByteSize(src)))) { 4740b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar return InvalidArgument( 4750b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar "Malformed TF_RESOUCE tensor: unable to parse resource handle"); 4760b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar } 4770b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar return Status::OK(); 4780b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar } 4794c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (src->dtype != TF_STRING) { 4802bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer); 4814c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return Status::OK(); 4824c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar } 4834c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar // TF_STRING tensors require copying since Tensor class expects a sequence of 4844c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar // string objects. 485f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const tensorflow::int64 num_elements = src->shape.num_elements(); 486f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); 487f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const size_t src_size = TF_TensorByteSize(src); 488f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) < 489f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur num_elements) { 4904c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return InvalidArgument( 491f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur "Malformed TF_STRING tensor; too short to hold number of elements"); 492f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 493f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; 494f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const char* limit = input + src_size; 495f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 496f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur *dst = Tensor(static_cast<DataType>(src->dtype), src->shape); 497ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne auto dstarray = dst->flat<string>(); 4987cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (tensorflow::int64 i = 0; i < num_elements; ++i) { 499f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensorflow::uint64 offset = 500f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur reinterpret_cast<const tensorflow::uint64*>(input)[i]; 501cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) { 5024c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return InvalidArgument("Malformed TF_STRING tensor; element ", i, 5034c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar " out of range"); 504f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 505cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar size_t len; 506cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar const char* p; 507cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar const char* srcp = data_start + offset; 5084c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len); 5094c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (!status.ok()) return status; 510f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur dstarray(i).assign(p, len); 511f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 5124c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return Status::OK(); 513f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 514f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 515b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to 516b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar// result in a zero-sized tensor. 517b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankarstatic TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { 518b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar static char empty; 519b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar tensorflow::int64 nelems = 1; 520b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar std::vector<tensorflow::int64> dims; 521b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar for (int i = 0; i < shape.dims(); ++i) { 522b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar dims.push_back(shape.dim_size(i)); 523b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar nelems *= shape.dim_size(i); 524b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar } 525b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar CHECK_EQ(nelems, 0); 526b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 527b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "64-bit int types should match in size"); 528b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()), 529b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar shape.dims(), reinterpret_cast<void*>(&empty), 0, 530b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar [](void*, size_t, void*) {}, nullptr); 531b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar} 532b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar 533f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Non-static for testing. 53496675956ef17e609d1bd60591fc998890d505004Asim ShankarTF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, 53596675956ef17e609d1bd60591fc998890d505004Asim Shankar TF_Status* status) { 53696675956ef17e609d1bd60591fc998890d505004Asim Shankar if (!src.IsInitialized()) { 53796675956ef17e609d1bd60591fc998890d505004Asim Shankar status->status = FailedPrecondition( 53896675956ef17e609d1bd60591fc998890d505004Asim Shankar "attempt to use a tensor with an uninitialized value"); 53996675956ef17e609d1bd60591fc998890d505004Asim Shankar return nullptr; 54096675956ef17e609d1bd60591fc998890d505004Asim Shankar } 541b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar if (src.NumElements() == 0) { 542b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape()); 543b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar } 5440b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (src.dtype() == DT_RESOURCE) { 5450b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar if (src.shape().dims() != 0) { 546b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar status->status = InvalidArgument( 547b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", 548b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar src.shape().DebugString(), 549b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "). Please file a bug at " 550b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "https://github.com/tensorflow/tensorflow/issues/new, " 551b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "ideally with a " 552b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar "short code snippet that reproduces this error."); 553b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar return nullptr; 5540b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar } 5550b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar const string str = src.scalar<ResourceHandle>()().SerializeAsString(); 5560b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size()); 5570b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar std::memcpy(TF_TensorData(t), str.c_str(), str.size()); 5580b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar return t; 5590b3a25d684ed2466d35661683baadc7eecce73d3Asim Shankar } 5604c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (src.dtype() != DT_STRING) { 5614c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar TensorBuffer* buf = TensorCApi::Buffer(src); 5624c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar buf->Ref(); 5634c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(), 5642bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar buf}; 5654c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar } 5664c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly 5674c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar // encoded sequence of strings. 5684c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar 569f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Compute bytes needed for encoding. 570f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur size_t size = 0; 571ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const auto& srcarray = src.flat<string>(); 5727cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < srcarray.size(); ++i) { 573ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& s = srcarray(i); 574cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar // uint64 starting_offset, TF_StringEncode-d string. 575cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size()); 576f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 577f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 578f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Encode all strings. 579f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur char* base = new char[size]; 580f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size(); 581f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur char* dst = data_start; // Where next string is encoded. 582cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar size_t dst_len = size - static_cast<size_t>(data_start - base); 583f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base); 5847cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < srcarray.size(); ++i) { 585f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur *offsets = (dst - data_start); 586f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur offsets++; 587ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& s = srcarray(i); 58896675956ef17e609d1bd60591fc998890d505004Asim Shankar size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); 58996675956ef17e609d1bd60591fc998890d505004Asim Shankar if (!status->status.ok()) { 59096675956ef17e609d1bd60591fc998890d505004Asim Shankar status->status = InvalidArgument( 59196675956ef17e609d1bd60591fc998890d505004Asim Shankar "invalid string tensor encoding (string #", i, " of ", 59296675956ef17e609d1bd60591fc998890d505004Asim Shankar srcarray.size(), "): ", status->status.error_message()); 59390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? delete[] base; 59496675956ef17e609d1bd60591fc998890d505004Asim Shankar return nullptr; 59596675956ef17e609d1bd60591fc998890d505004Asim Shankar } 596cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar dst += consumed; 597cfcc6060d71d820df70fcab9a449ad45ddde72efAsim Shankar dst_len -= consumed; 598f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 59996675956ef17e609d1bd60591fc998890d505004Asim Shankar if (dst != base + size) { 60096675956ef17e609d1bd60591fc998890d505004Asim Shankar status->status = InvalidArgument( 60196675956ef17e609d1bd60591fc998890d505004Asim Shankar "invalid string tensor encoding (decoded ", (dst - base), 60296675956ef17e609d1bd60591fc998890d505004Asim Shankar " bytes, but the tensor is encoded in ", size, " bytes"); 60390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? delete[] base; 60496675956ef17e609d1bd60591fc998890d505004Asim Shankar return nullptr; 60596675956ef17e609d1bd60591fc998890d505004Asim Shankar } 606f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 607f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur auto dims = src.shape().dim_sizes(); 608f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<tensorflow::int64> dimvec(dims.size()); 6097cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (size_t i = 0; i < dims.size(); ++i) { 610f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur dimvec[i] = dims[i]; 611f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 612f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 613f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower "64-bit int types should match in size"); 614f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return TF_NewTensor(TF_STRING, 615f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower reinterpret_cast<const int64_t*>(dimvec.data()), 616f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower dimvec.size(), base, size, DeleteArray, base); 617f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 618f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 6199624d165f1f2c717eda96464fee8bf7229cc14f5Igor GanichevStatus MessageToBuffer(const tensorflow::protobuf::Message& in, 6209624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev TF_Buffer* out) { 6219624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev if (out->data != nullptr) { 6229624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return InvalidArgument("Passing non-empty TF_Buffer is invalid."); 6239624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev } 6249624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev const size_t proto_size = in.ByteSizeLong(); 6259624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev void* buf = tensorflow::port::Malloc(proto_size); 6269624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev if (buf == nullptr) { 6279624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return tensorflow::errors::ResourceExhausted( 6289624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev "Failed to allocate memory to serialize message of type '", 6299624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev in.GetTypeName(), "' and size ", proto_size); 6309624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev } 6319624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev in.SerializeToArray(buf, proto_size); 6329624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev out->data = buf; 6339624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev out->length = proto_size; 6349624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev out->data_deallocator = [](void* data, size_t length) { 6359624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev tensorflow::port::Free(data); 6369624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev }; 6379624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return Status::OK(); 6389624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev} 6399624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 640cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichevvoid RecordMutation(TF_Graph* graph, const TF_Operation& op, 641cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev const char* mutation_type) 642cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { 643cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev // If any session has already run this node_id, mark this session as 644cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev // unrunnable. 645cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev for (auto it : graph->sessions) { 646cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev if (it.first->last_num_graph_nodes > op.node.id()) { 647cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev it.second = FailedPrecondition( 648cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev "Operation '", op.node.DebugString(), "' was changed by ", 649cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev mutation_type, 650cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev " after it was run by a session. Nodes can be mutated " 651cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev "only before they are executed by a session. Either don't modify " 652cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev "nodes after running them or create a new session."); 653cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev } 654cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev } 655cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev} 656cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev 6571938feab97e36275f18a0745804299acfe137dc8Akshay Agrawalnamespace { 6581938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 6591938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal// Helper method that creates a shape handle for a shape described by dims. 6601938feab97e36275f18a0745804299acfe137dc8Akshay Agrawaltensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( 6611938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal tensorflow::shape_inference::InferenceContext* ic, int num_dims, 6621938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal const int64_t* dims) { 6631938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal if (num_dims != -1) { 6641938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; 6651938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal dim_vec.reserve(num_dims); 6661938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal for (int i = 0; i < num_dims; ++i) { 6671938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal dim_vec.push_back(ic->MakeDim(dims[i])); 6681938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal } 6691938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal return ic->MakeShape(dim_vec); 6701938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal } else { 6711938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal return ic->UnknownShape(); 6721938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal } 6731938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal} 6741938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 6751938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal} // namespace 6761938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 6771938feab97e36275f18a0745804299acfe137dc8Akshay Agrawalvoid TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, 6781938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal int num_shapes_and_types, 6791938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal const int64_t** shapes, 6801938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal const int* ranks, 6811938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal const TF_DataType* types, 6821938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal TF_Status* status) { 6831938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal Node* node = &output.oper->node; 6841938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 6851938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal mutex_lock l(graph->mu); 6861938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal tensorflow::shape_inference::InferenceContext* ic = 6871938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal graph->refiner.GetContext(node); 6881938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal if (ic == nullptr) { 6891938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal status->status = 6901938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal InvalidArgument("Node ", node->name(), " was not found in the graph"); 6911938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal return; 6921938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal } 6931938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 6941938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal auto shape_and_type_vec = 6951938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal std::vector<tensorflow::shape_inference::ShapeAndType>( 6961938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal num_shapes_and_types); 6971938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal for (int i = 0; i < num_shapes_and_types; ++i) { 6981938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal tensorflow::shape_inference::ShapeHandle shape_handle = 6991938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal ShapeHandleFromDims(ic, ranks[i], shapes[i]); 7001938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( 7011938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal shape_handle, static_cast<DataType>(types[i])); 7021938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal } 7031938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 7041938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); 7051938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal} 7061938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal 7071c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower// Helpers for loading a TensorFlow plugin (a .so file). 7081c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlowerStatus LoadLibrary(const char* library_filename, void** result, 7091c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower const void** buf, size_t* len); 7101c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower 711f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace tensorflow 712f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 7137cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlowerstatic void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, 7147cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Status* status) { 715f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = Status::OK(); 7167cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 717d83074847ebfe8871188f1f9f1e84ab0451f59e6A. Unique TensorFlower c_outputs[i] = nullptr; 718f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 7197cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower} 720f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 721ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnestatic bool TF_Run_Inputs(TF_Tensor* const* c_inputs, 722ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<std::pair<string, Tensor>>* input_pairs, 723ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Status* status) { 7247cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower const int ninputs = input_pairs->size(); 7257cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 7264c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); 7274c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (!status->status.ok()) return false; 728f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 729d9da9721f45950035f5087c59f9bc6910e232271Asim Shankar return true; 730d9da9721f45950035f5087c59f9bc6910e232271Asim Shankar} 731d9da9721f45950035f5087c59f9bc6910e232271Asim Shankar 7327cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlowerstatic void TF_Run_Helper( 7337cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower Session* session, const char* handle, const TF_Buffer* run_options, 7347cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // Input tensors 735ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const std::vector<std::pair<string, Tensor>>& input_pairs, 7367cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // Output tensors 737ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs, 7387cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // Target nodes 739ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const std::vector<string>& target_oper_names, TF_Buffer* run_metadata, 740ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Status* status) { 7417cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower const int noutputs = output_tensor_names.size(); 742f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<Tensor> outputs(noutputs); 7438a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower Status result; 744b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 7458a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower if (handle == nullptr) { 7463e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang RunOptions run_options_proto; 747d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar if (run_options != nullptr && !run_options_proto.ParseFromArray( 748d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar run_options->data, run_options->length)) { 749c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Unparseable RunOptions proto"); 7503e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang return; 7513e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang } 7523e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang if (run_metadata != nullptr && run_metadata->data != nullptr) { 753c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = 754c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar InvalidArgument("Passing non-empty run_metadata is invalid."); 7553e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang return; 7563e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang } 757b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang 7583e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang RunMetadata run_metadata_proto; 7597cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower result = session->Run(run_options_proto, input_pairs, output_tensor_names, 760a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_oper_names, &outputs, &run_metadata_proto); 7613e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang 7623e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang // Serialize back to upstream client, who now owns the new buffer 7633e1ec828404d63e211c5c0ed19f29c78619ef401Zongheng Yang if (run_metadata != nullptr) { 76495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(run_metadata_proto, run_metadata); 76595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 766b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang } 7678a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } else { 768b54ad5788b7cec515d81e9f7e3c9811fd8181235Zongheng Yang // NOTE(zongheng): PRun does not support RunOptions yet. 7697cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); 7708a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } 771f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (!result.ok()) { 772f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur status->status = result; 773f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return; 774f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 775f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 776f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Store results in c_outputs[] 7777cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 778f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& src = outputs[i]; 779379df09118ddfdbad19375d6d853254312ccf1aeA. Unique TensorFlower if (!src.IsInitialized() || src.NumElements() == 0) { 780b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar c_outputs[i] = 781b42ed42cdf3eb13a9412ce6ca08183d19322fd6fAsim Shankar EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape()); 782f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur continue; 783f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 78496675956ef17e609d1bd60591fc998890d505004Asim Shankar c_outputs[i] = TF_TensorFromTensor(src, status); 78596675956ef17e609d1bd60591fc998890d505004Asim Shankar if (!status->status.ok()) return; 786f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 787f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 788f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 7898a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlowerextern "C" { 7908a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower 791fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, 7928a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Input tensors 7938a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower const char** c_input_names, TF_Tensor** c_inputs, int ninputs, 7948a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Output tensors 795a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_output_names, TF_Tensor** c_outputs, int noutputs, 7968a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Target nodes 797a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_target_oper_names, int ntargets, 798b49474d2794a86b4f4e27aefb0998e1c7734135bDan Smilkov TF_Buffer* run_metadata, TF_Status* status) { 7997cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Setup(noutputs, c_outputs, status); 800ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<std::pair<string, Tensor>> input_pairs(ninputs); 8011f0c5119a0230c5160d45496175b9256f097e144Asim Shankar if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; 8027cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 8037cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower input_pairs[i].first = c_input_names[i]; 8047cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 805ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 8067cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 807a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower output_names[i] = c_output_names[i]; 8087cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 809ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_oper_names(ntargets); 8107cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 811a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_oper_names[i] = c_target_oper_names[i]; 8127cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 813a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, 814a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower c_outputs, target_oper_names, run_metadata, status); 8158a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower} 8168a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower 817fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_PRunSetup(TF_DeprecatedSession* s, 8188a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Input names 8198a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower const char** c_input_names, int ninputs, 8208a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Output names 821a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_output_names, int noutputs, 8228a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Target nodes 823a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_target_oper_names, int ntargets, 8247cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower const char** handle, TF_Status* status) { 825348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower *handle = nullptr; 8268a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower 827ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> input_names(ninputs); 828ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 829ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_oper_names(ntargets); 8307cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 8318a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower input_names[i] = c_input_names[i]; 8328a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } 8337cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 834a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower output_names[i] = c_output_names[i]; 8358a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } 8367cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 837a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_oper_names[i] = c_target_oper_names[i]; 8388a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } 839ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne string new_handle; 840348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower status->status = s->session->PRunSetup(input_names, output_names, 841348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower target_oper_names, &new_handle); 842348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower if (status->status.ok()) { 8437cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower char* buf = new char[new_handle.size() + 1]; 8447cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower memcpy(buf, new_handle.c_str(), new_handle.size() + 1); 8457cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower *handle = buf; 8468a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower } 8478a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower} 8488a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower 849fdd94f24e9ebfe58bd967a8c2c7b8985c03ff80bAsim Shankarvoid TF_PRun(TF_DeprecatedSession* s, const char* handle, 8508a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Input tensors 8518a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower const char** c_input_names, TF_Tensor** c_inputs, int ninputs, 8528a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Output tensors 853a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_output_names, TF_Tensor** c_outputs, int noutputs, 8548a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower // Target nodes 855a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char** c_target_oper_names, int ntargets, 8568a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower TF_Status* status) { 8577cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Setup(noutputs, c_outputs, status); 858ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<std::pair<string, Tensor>> input_pairs(ninputs); 8591f0c5119a0230c5160d45496175b9256f097e144Asim Shankar if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; 8607cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 8617cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower input_pairs[i].first = c_input_names[i]; 8627cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 8637cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 864ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 8657cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 866a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower output_names[i] = c_output_names[i]; 8677cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 868ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_oper_names(ntargets); 8697cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 870a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_oper_names[i] = c_target_oper_names[i]; 8717cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 872a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, 873a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower c_outputs, target_oper_names, nullptr, status); 8748a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower} 8758a59748c087a2fee535c0d5067dbabb01920e812A. Unique TensorFlower 8761c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlowerTF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { 8771c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower TF_Library* lib_handle = new TF_Library; 8781c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower status->status = tensorflow::LoadLibrary( 8791c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, 8801c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower &lib_handle->op_list.length); 8811c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower if (!status->status.ok()) { 8821c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower delete lib_handle; 8831c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower return nullptr; 8841c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower } 8851c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower return lib_handle; 8861c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower} 8871c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower 8881c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlowerTF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } 8891c579361cd1e088dd5e05a394b1561a73e3667baA. Unique TensorFlower 89056b003d724ecb3e7781165fd4504d058982959c1A. Unique TensorFlowervoid TF_DeleteLibraryHandle(TF_Library* lib_handle) { 89183c6e0c63acdcab2c58c4ed7220bfa58879b1d57Jonathan Hseu tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data)); 89256b003d724ecb3e7781165fd4504d058982959c1A. Unique TensorFlower delete lib_handle; 89356b003d724ecb3e7781165fd4504d058982959c1A. Unique TensorFlower} 89456b003d724ecb3e7781165fd4504d058982959c1A. Unique TensorFlower 895b7541f67c21b5a12120af0b7ac33404bd5160643Asim ShankarTF_Buffer* TF_GetAllOpList() { 896b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar std::vector<tensorflow::OpDef> op_defs; 897b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); 898b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar tensorflow::OpList op_list; 899b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar for (const auto& op : op_defs) { 900b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar *(op_list.add_op()) = op; 901b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar } 902b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar TF_Buffer* ret = TF_NewBuffer(); 903bc225bfaa534acc25047fe844f19edc333b7a76aPeter Hawkins TF_CHECK_OK(MessageToBuffer(op_list, ret)); 904b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar return ret; 905b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar} 906b7541f67c21b5a12120af0b7ac33404bd5160643Asim Shankar 9079c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta// -------------------------------------------------------------------------- 9089c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta// ListDevices & SessionListDevices API 9099c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9109c495f9499199ea46fff9028774374fa0c52e018Brennan Saetavoid TF_DeleteDeviceList(TF_DeviceList* s) { delete s; } 9119c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9129c495f9499199ea46fff9028774374fa0c52e018Brennan SaetaTF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { 9139c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta TF_DeviceList* response = new TF_DeviceList; 9149c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta status->status = session->session->ListDevices(&response->response); 9159c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return response; 9169c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta} 9179c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9189c495f9499199ea46fff9028774374fa0c52e018Brennan SaetaTF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, 9199c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta TF_Status* status) { 9209c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta TF_DeviceList* response = new TF_DeviceList; 9219c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta status->status = session->session->ListDevices(&response->response); 9229c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return response; 9239c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta} 9249c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9259c495f9499199ea46fff9028774374fa0c52e018Brennan Saetaint TF_DeviceListCount(const TF_DeviceList* list) { 9269c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return list->response.size(); 9279c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta} 9289c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9299c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ 9309c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return_type method_name(const TF_DeviceList* list, const int index, \ 9319c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta TF_Status* status) { \ 9329c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta if (list == nullptr) { \ 9339c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta status->status = InvalidArgument("list is null!"); \ 9349c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return err_val; \ 9359c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta } \ 9369c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta if (index < 0 || index >= list->response.size()) { \ 9379c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta status->status = InvalidArgument("index out of bounds"); \ 9389c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return err_val; \ 9399c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta } \ 940b3294d74c73d8f01412ffeeac616d0b8a2f4e18eA. Unique TensorFlower status->status = Status::OK(); \ 9419c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta return list->response[index].accessor; \ 9429c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta } 9439c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9449c495f9499199ea46fff9028774374fa0c52e018Brennan SaetaTF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); 9459c495f9499199ea46fff9028774374fa0c52e018Brennan SaetaTF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), 9469c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta nullptr); 9479c495f9499199ea46fff9028774374fa0c52e018Brennan SaetaTF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); 9489c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 9499c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta#undef TF_DEVICELIST_METHOD 9509c495f9499199ea46fff9028774374fa0c52e018Brennan Saeta 951f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // end extern "C" 952f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 953f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower// -------------------------------------------------------------------------- 954f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower// New Graph and Session API 955f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 956f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower// Helper functions ----------------------------------------------------------- 957f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 958f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowernamespace { 959f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 960a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerTF_Operation* ToOperation(Node* node) { 961a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return static_cast<TF_Operation*>(static_cast<void*>(node)); 962f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 963f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 964ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnestring OutputName(const TF_Output& output) { 965661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return StrCat(output.oper->node.name(), ":", output.index); 966f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 967f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 96895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarconst tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, 96995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const char* attr_name, 97095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 97173882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); 97295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr == nullptr) { 9736c95675492aa8d25619f5e4ce1674582c051a7feSkye Wanderman-Milne status->status = InvalidArgument("Operation '", oper->node.name(), 9746c95675492aa8d25619f5e4ce1674582c051a7feSkye Wanderman-Milne "' has no attr named '", attr_name, "'."); 97595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 97695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return attr; 97795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 97895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 9790fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-MilneTensorId ToTensorId(const TF_Output& output) { 9800fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return TensorId(output.oper->node.name(), output.index); 9810fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne} 9820fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 9830fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#ifndef __ANDROID__ 9840fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnestd::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs, 9850fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne int n) { 9860fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output> outputs(n); 9870fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (int i = 0; i < n; ++i) { 9880fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne outputs[i] = 9890fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index); 9900fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne } 9910fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return outputs; 9920fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne} 9930fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 9940fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnevoid TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs, 9950fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne TF_Output* tf_outputs) { 9960fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (int i = 0; i < outputs.size(); i++) { 9970fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tf_outputs[i].oper = ToOperation(outputs[i].node()); 9980fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tf_outputs[i].index = outputs[i].index(); 9990fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne } 10000fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne} 10010fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#endif // __ANDROID__ 10020fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 1003f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} // namespace 1004f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 10052677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan// Shape functions ----------------------------------------------------------- 10062677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10078f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuvoid TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, 10088f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu const int64_t* dims, const int num_dims, 10098f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu TF_Status* status) { 10108f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu Node* node = &output.oper->node; 10112677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10122677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan mutex_lock l(graph->mu); 10132677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan tensorflow::shape_inference::InferenceContext* ic = 10142677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan graph->refiner.GetContext(node); 10152677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (ic == nullptr) { 1016c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = 1017c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar InvalidArgument("Node ", node->name(), " was not found in the graph"); 10182677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return; 10192677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10201938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal tensorflow::shape_inference::ShapeHandle new_shape = 10211938feab97e36275f18a0745804299acfe137dc8Akshay Agrawal tensorflow::ShapeHandleFromDims(ic, num_dims, dims); 10228f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu status->status = graph->refiner.SetShape(node, output.index, new_shape); 10232677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan} 10242677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10258f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuint TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, 10268f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu TF_Status* status) { 10278f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu Node* node = &output.oper->node; 10282677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10292677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan mutex_lock l(graph->mu); 10302677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan tensorflow::shape_inference::InferenceContext* ic = 10312677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan graph->refiner.GetContext(node); 10322677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (ic == nullptr) { 1033c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = 1034c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar InvalidArgument("Node ", node->name(), " was not found in the graph"); 10352677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return -1; 10362677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10372677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10388f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); 10392677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10402677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // Unknown rank means the number of dimensions is -1. 10412677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (!ic->RankKnown(shape)) { 10422677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return -1; 10432677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10442677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10452677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return ic->Rank(shape); 10462677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan} 10472677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10488f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuvoid TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, 10492677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan int num_dims, TF_Status* status) { 10508f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu Node* node = &output.oper->node; 10512677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10522677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan mutex_lock l(graph->mu); 10532677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan tensorflow::shape_inference::InferenceContext* ic = 10542677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan graph->refiner.GetContext(node); 10552677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (ic == nullptr) { 1056c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = 1057c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar InvalidArgument("Node ", node->name(), " was not found in the graph"); 10582677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return; 10592677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10602677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10618f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); 10622677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10632677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan int rank = -1; 10642677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (ic->RankKnown(shape)) { 10652677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan rank = ic->Rank(shape); 10662677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10672677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10682677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (num_dims != rank) { 1069c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Expected rank is ", num_dims, 1070c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar " but actual rank is ", rank); 10712677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return; 10722677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10732677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10742677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (num_dims == 0) { 10752677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // Output shape is a scalar. 10762677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan return; 10772677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10782677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 10792677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // Rank is greater than 0, so fill in the values, if known, and 10802677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // -1 for unknown values. 10812677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan for (int i = 0; i < num_dims; ++i) { 10822677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan auto dim = ic->Dim(shape, i); 10832677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan tensorflow::int64 value = -1; 10842677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan if (ic->ValueKnown(dim)) { 10852677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan value = ic->Value(dim); 10862677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10872677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan dims[i] = value; 10882677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan } 10892677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan} 10902677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 1091457cc412bbd0caea7f11b6690336048ed7455fa6A. Unique TensorFlower// TF_OperationDescription functions ------------------------------------------ 1092f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1093f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerextern "C" { 1094f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1095661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnestatic TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, 1096661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne const char* op_type, 1097661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne const char* oper_name) 1098661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { 1099661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return new TF_OperationDescription(graph, op_type, oper_name); 1100661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 1101661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 1102a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerTF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, 1103a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const char* oper_name) { 1104f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock l(graph->mu); 1105661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return TF_NewOperationLocked(graph, op_type, oper_name); 1106f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1107f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1108a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetDevice(TF_OperationDescription* desc, const char* device) { 1109f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Device(device); 1110f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1111f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 11128f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuvoid TF_AddInput(TF_OperationDescription* desc, TF_Output input) { 1113a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower desc->node_builder.Input(&input.oper->node, input.index); 1114f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1115f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 11168f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuvoid TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, 1117f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int num_inputs) { 1118f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower std::vector<NodeBuilder::NodeOut> input_list; 1119f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower input_list.reserve(num_inputs); 1120f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < num_inputs; ++i) { 1121a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); 1122f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1123f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Input(input_list); 1124f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1125f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1126a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { 1127f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.ControlInput(&input->node); 1128f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1129f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 113034501544c48061bdfb6a0e7cddf1f70136cf3040Asim Shankarvoid TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { 113157626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->colocation_constraints.emplace( 1132661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); 113334501544c48061bdfb6a0e7cddf1f70136cf3040Asim Shankar} 113434501544c48061bdfb6a0e7cddf1f70136cf3040Asim Shankar 1135a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, 1136786938758a39940ae0154834bfed9e21894afa28Asim Shankar const void* value, size_t length) { 1137f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower tensorflow::StringPiece s(static_cast<const char*>(value), length); 1138f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, s); 1139f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1140f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1141a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, 1142786938758a39940ae0154834bfed9e21894afa28Asim Shankar const void* const* values, const size_t* lengths, 1143f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int num_values) { 114457626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { 114557626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->colocation_constraints.clear(); 114657626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower for (int i = 0; i < num_values; ++i) { 114757626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->colocation_constraints.emplace(static_cast<const char*>(values[i]), 114857626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower lengths[i]); 114957626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 115057626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } else { 115157626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower std::vector<tensorflow::StringPiece> v; 115257626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower v.reserve(num_values); 115357626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower for (int i = 0; i < num_values; ++i) { 115457626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); 115557626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 115657626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->node_builder.Attr(attr_name, v); 1157f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1158f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1159f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1160a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, 1161f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int64_t value) { 1162f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 1163f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower "64-bit int types should match in size"); 1164f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value)); 1165f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1166f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1167a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, 1168f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const int64_t* values, int num_values) { 1169f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 1170f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower "64-bit int types should match in size"); 1171f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr( 1172f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower attr_name, 1173f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ArraySlice<const tensorflow::int64>( 1174f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower reinterpret_cast<const tensorflow::int64*>(values), num_values)); 1175f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1176f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1177a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, 1178f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower float value) { 1179f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, value); 1180f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1181f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1182a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, 1183f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const float* values, int num_values) { 1184f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, 1185f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ArraySlice<const float>(values, num_values)); 1186f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1187f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1188a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, 1189f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower unsigned char value) { 1190f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, static_cast<bool>(value)); 1191f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1192f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1193a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name, 1194f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const unsigned char* values, int num_values) { 1195f30403181f0ccbb30bb38777fbccf6c1b6576ca9A. Unique TensorFlower std::unique_ptr<bool[]> b(new bool[num_values]); 1196f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < num_values; ++i) { 1197f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower b[i] = values[i]; 1198f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1199f30403181f0ccbb30bb38777fbccf6c1b6576ca9A. Unique TensorFlower desc->node_builder.Attr(attr_name, 1200f30403181f0ccbb30bb38777fbccf6c1b6576ca9A. Unique TensorFlower ArraySlice<const bool>(b.get(), num_values)); 1201f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1202f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1203a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, 1204f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_DataType value) { 1205f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, static_cast<DataType>(value)); 1206f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1207f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1208a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, 1209f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const TF_DataType* values, int num_values) { 1210f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr( 1211f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower attr_name, ArraySlice<const DataType>( 1212f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower reinterpret_cast<const DataType*>(values), num_values)); 1213f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1214f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1215afbf1e3ab3cd1ba0ebec53483c9d3a05b9c51554A. Unique TensorFlowervoid TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, 1216afbf1e3ab3cd1ba0ebec53483c9d3a05b9c51554A. Unique TensorFlower const char* value, size_t length) { 1217baea13831c2d1ffa08c4fcc8944a3870d19826cbA. Unique TensorFlower tensorflow::NameAttrList func_name; 1218baea13831c2d1ffa08c4fcc8944a3870d19826cbA. Unique TensorFlower func_name.set_name(std::string(value, value + length)); 1219baea13831c2d1ffa08c4fcc8944a3870d19826cbA. Unique TensorFlower desc->node_builder.Attr(attr_name, func_name); 1220baea13831c2d1ffa08c4fcc8944a3870d19826cbA. Unique TensorFlower} 1221baea13831c2d1ffa08c4fcc8944a3870d19826cbA. Unique TensorFlower 1222a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, 1223f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const int64_t* dims, int num_dims) { 1224f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower PartialTensorShape shape; 1225f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (num_dims >= 0) { 1226f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 1227f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower "64-bit int types should match in size"); 1228f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower shape = PartialTensorShape(ArraySlice<tensorflow::int64>( 1229f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower reinterpret_cast<const tensorflow::int64*>(dims), num_dims)); 1230f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1231f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, shape); 1232f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1233f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1234a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, 1235f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const int64_t* const* dims, const int* num_dims, 1236f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int num_shapes) { 1237f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower std::vector<PartialTensorShape> shapes; 1238f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower shapes.reserve(num_shapes); 1239f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < num_shapes; ++i) { 1240f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (num_dims[i] < 0) { 1241f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower shapes.emplace_back(); 1242f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 1243f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), 1244f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower "64-bit int types should match in size"); 1245f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower shapes.emplace_back(ArraySlice<tensorflow::int64>( 1246f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i])); 1247f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1248f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1249f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, shapes); 1250f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1251f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1252a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, 1253a559f91041f144820266fc0a751e09267acf4428A. Unique TensorFlower const char* attr_name, const void* proto, 1254786938758a39940ae0154834bfed9e21894afa28Asim Shankar size_t proto_len, TF_Status* status) { 1255786938758a39940ae0154834bfed9e21894afa28Asim Shankar // shape.ParseFromArray takes an int as length, this function takes size_t, 1256786938758a39940ae0154834bfed9e21894afa28Asim Shankar // make sure there is no information loss. 1257786938758a39940ae0154834bfed9e21894afa28Asim Shankar if (proto_len > std::numeric_limits<int>::max()) { 1258786938758a39940ae0154834bfed9e21894afa28Asim Shankar status->status = InvalidArgument( 1259786938758a39940ae0154834bfed9e21894afa28Asim Shankar "proto_len (", proto_len, 1260786938758a39940ae0154834bfed9e21894afa28Asim Shankar " bytes) is too large to be parsed by the protocol buffer library"); 1261786938758a39940ae0154834bfed9e21894afa28Asim Shankar return; 1262786938758a39940ae0154834bfed9e21894afa28Asim Shankar } 1263f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TensorShapeProto shape; 1264786938758a39940ae0154834bfed9e21894afa28Asim Shankar if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) { 1265f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, shape); 1266f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = Status::OK(); 1267f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 1268c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Unparseable TensorShapeProto"); 1269f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1270f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1271f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1272a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, 1273f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const char* attr_name, 1274f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const void* const* protos, 1275786938758a39940ae0154834bfed9e21894afa28Asim Shankar const size_t* proto_lens, int num_shapes, 1276f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_Status* status) { 1277f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower std::vector<TensorShapeProto> shapes; 1278f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower shapes.resize(num_shapes); 1279f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < num_shapes; ++i) { 1280786938758a39940ae0154834bfed9e21894afa28Asim Shankar if (proto_lens[i] > std::numeric_limits<int>::max()) { 1281786938758a39940ae0154834bfed9e21894afa28Asim Shankar status->status = InvalidArgument( 1282786938758a39940ae0154834bfed9e21894afa28Asim Shankar "length of element ", i, " in the list (", proto_lens[i], 1283786938758a39940ae0154834bfed9e21894afa28Asim Shankar " bytes) is too large to be parsed by the protocol buffer library"); 1284786938758a39940ae0154834bfed9e21894afa28Asim Shankar return; 1285786938758a39940ae0154834bfed9e21894afa28Asim Shankar } 1286786938758a39940ae0154834bfed9e21894afa28Asim Shankar if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) { 1287c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = 1288c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar InvalidArgument("Unparseable TensorShapeProto at index ", i); 1289f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return; 1290f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1291f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1292f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->node_builder.Attr(attr_name, shapes); 1293f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = Status::OK(); 1294f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1295f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1296a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, 1297f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_Tensor* value, TF_Status* status) { 1298f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower Tensor t; 12994c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar status->status = TF_TensorToTensor(value, &t); 13004c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (status->status.ok()) desc->node_builder.Attr(attr_name, t); 1301f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1302f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1303a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowervoid TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, 1304f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_Tensor* const* values, int num_values, 1305f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_Status* status) { 1306f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = Status::OK(); 1307f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower std::vector<Tensor> t; 1308f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower t.reserve(num_values); 1309f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 13104c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar for (int i = 0; i < num_values && status->status.ok(); ++i) { 13114c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar Tensor v; 13124c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar status->status = TF_TensorToTensor(values[i], &v); 13134c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar t.emplace_back(v); 1314f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1315f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 13164c9e344bf1b6582620b26c0a62a886d3c80e3c19Asim Shankar if (status->status.ok()) desc->node_builder.Attr(attr_name, t); 1317f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1318f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 131995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, 132095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const void* proto, size_t proto_len, 132195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 1322f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower tensorflow::AttrValue attr_value; 132357626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower if (!attr_value.ParseFromArray(proto, proto_len)) { 1324c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Unparseable AttrValue proto"); 132557626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower return; 1326f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 132757626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower 132857626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { 132957626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower if (attr_value.value_case() != tensorflow::AttrValue::kList && 133057626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { 133157626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower status->status = 133257626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower InvalidArgument("Expected \"list\" field for \"", 133357626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower tensorflow::kColocationAttrName, "\" attribute"); 133457626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower return; 133557626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 133657626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->colocation_constraints.clear(); 1337ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne for (const string& location : attr_value.list().s()) { 133857626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->colocation_constraints.insert(location); 133957626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 134057626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } else { 134157626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->node_builder.Attr(attr_name, attr_value); 134257626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 134357626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower 134457626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower status->status = Status::OK(); 1345f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1346f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1347661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnestatic TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, 1348661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Status* status) 1349661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { 1350f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower Node* ret = nullptr; 1351f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1352f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (desc->graph->name_map.count(desc->node_builder.node_name())) { 1353c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Duplicate node name in graph: '", 1354c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar desc->node_builder.node_name(), "'"); 1355f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 135657626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower if (!desc->colocation_constraints.empty()) { 135757626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower desc->node_builder.Attr( 135857626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower tensorflow::kColocationAttrName, 1359ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string>(desc->colocation_constraints.begin(), 1360ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne desc->colocation_constraints.end())); 136157626dd38a7867b76c44f3933e7810190174a2eeA. Unique TensorFlower } 1362f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); 13632677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan 1364f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (status->status.ok()) { 13652677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // Run shape inference function for newly added node. 13667d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar status->status = desc->graph->refiner.AddNode(ret); 13677d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar } 13687d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar if (status->status.ok()) { 13692677d3be19952488662a22cf9f42374a493ffd50Vijay Vasudevan // Add the node to the name-to-node mapping. 1370f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower desc->graph->name_map[ret->name()] = ret; 13717d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar } else if (ret != nullptr) { 13727d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar desc->graph->graph.RemoveNode(ret); 13737d785f1e18af9d22d940f18aac6e8c9ffd268b22Asim Shankar ret = nullptr; 1374f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1375f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1376f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1377f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower delete desc; 1378f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1379a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return ToOperation(ret); 1380f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1381f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1382661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-MilneTF_Operation* TF_FinishOperation(TF_OperationDescription* desc, 1383661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Status* status) { 1384661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne mutex_lock l(desc->graph->mu); 1385661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return TF_FinishOperationLocked(desc, status); 1386661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 1387661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 1388a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower// TF_Operation functions 1389a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower// ---------------------------------------------------------- 1390f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1391a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerconst char* TF_OperationName(TF_Operation* oper) { 1392a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return oper->node.name().c_str(); 1393a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower} 1394f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1395a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerconst char* TF_OperationOpType(TF_Operation* oper) { 1396a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return oper->node.type_string().c_str(); 1397f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1398f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1399a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerconst char* TF_OperationDevice(TF_Operation* oper) { 140073882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving return oper->node.requested_device().c_str(); 1401f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1402f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1403a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationNumOutputs(TF_Operation* oper) { 1404a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return oper->node.num_outputs(); 1405a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower} 1406f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 14078f6cb22c675b5c0b553334a8f04daef462905d69Jonathan HseuTF_DataType TF_OperationOutputType(TF_Output oper_out) { 1408f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return static_cast<TF_DataType>( 1409a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower oper_out.oper->node.output_type(oper_out.index)); 1410f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1411f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1412a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, 1413a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Status* status) { 1414078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower NameRangeMap name_ranges; 141573882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = 141673882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); 1417078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower if (!status->status.ok()) return -1; 1418078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower auto iter = name_ranges.find(arg_name); 1419078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower if (iter == name_ranges.end()) { 1420c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Input arg '", arg_name, "' not found"); 1421078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower return -1; 1422078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower } 1423078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower return iter->second.second - iter->second.first; 1424078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower} 1425078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower 1426a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationNumInputs(TF_Operation* oper) { 1427a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return oper->node.num_inputs(); 1428a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower} 1429f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 14308f6cb22c675b5c0b553334a8f04daef462905d69Jonathan HseuTF_DataType TF_OperationInputType(TF_Input oper_in) { 1431a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index)); 1432f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1433f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1434a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, 1435a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Status* status) { 1436078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower NameRangeMap name_ranges; 143773882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = 143873882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); 1439078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower if (!status->status.ok()) return -1; 1440078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower auto iter = name_ranges.find(arg_name); 1441078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower if (iter == name_ranges.end()) { 1442c0169dc34a99d8541bd420ddf7b73e1e37dfbf19Asim Shankar status->status = InvalidArgument("Input arg '", arg_name, "' not found"); 1443078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower return -1; 1444078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower } 1445078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower return iter->second.second - iter->second.first; 1446078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower} 1447078d4a2828360728e4424b3ef057808d9185a87eA. Unique TensorFlower 14488f6cb22c675b5c0b553334a8f04daef462905d69Jonathan HseuTF_Output TF_OperationInput(TF_Input oper_in) { 14497a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan const tensorflow::Edge* edge; 14507a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); 14517a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan if (!s.ok()) { 14527a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan return {nullptr, -1}; 1453f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 14547a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan 14557a95d4a88eda1ab94b2cab1eed48fb82f9879b65Vijay Vasudevan return {ToOperation(edge->src()), edge->src_output()}; 1456f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1457f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 14588f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuint TF_OperationOutputNumConsumers(TF_Output oper_out) { 1459f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int count = 0; 1460a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower for (const auto* edge : oper_out.oper->node.out_edges()) { 1461a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower if (edge->src_output() == oper_out.index) { 1462f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ++count; 1463f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1464f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1465f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return count; 1466f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1467f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 14688f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuint TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, 1469a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower int max_consumers) { 1470f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int count = 0; 1471a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower for (const auto* edge : oper_out.oper->node.out_edges()) { 1472a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower if (edge->src_output() == oper_out.index) { 1473f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (count < max_consumers) { 1474a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; 1475f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1476f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ++count; 1477f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1478f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1479f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return count; 1480f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1481f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1482a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationNumControlInputs(TF_Operation* oper) { 1483d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne int count = 0; 1484d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne for (const auto* edge : oper->node.in_edges()) { 1485d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne if (edge->IsControlEdge() && !edge->src()->IsSource()) { 1486d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne ++count; 1487d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne } 1488d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne } 1489d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne return count; 1490f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1491f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1492a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationGetControlInputs(TF_Operation* oper, 1493a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Operation** control_inputs, 1494a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower int max_control_inputs) { 1495f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int count = 0; 1496a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower for (const auto* edge : oper->node.in_edges()) { 1497d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne if (edge->IsControlEdge() && !edge->src()->IsSource()) { 1498f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (count < max_control_inputs) { 1499a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower control_inputs[count] = ToOperation(edge->src()); 1500f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1501f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ++count; 1502f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1503f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1504f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return count; 1505f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1506f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1507a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationNumControlOutputs(TF_Operation* oper) { 1508f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int count = 0; 1509a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower for (const auto* edge : oper->node.out_edges()) { 1510d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne if (edge->IsControlEdge() && !edge->dst()->IsSink()) { 1511f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ++count; 1512f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1513f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1514f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return count; 1515f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1516f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1517a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerint TF_OperationGetControlOutputs(TF_Operation* oper, 1518a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower TF_Operation** control_outputs, 1519a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower int max_control_outputs) { 1520f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower int count = 0; 1521a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower for (const auto* edge : oper->node.out_edges()) { 1522d50f68df6d1e4e41f4486ee71626454f6bd3ffe4Skye Wanderman-Milne if (edge->IsControlEdge() && !edge->dst()->IsSink()) { 1523f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (count < max_control_outputs) { 1524a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower control_outputs[count] = ToOperation(edge->dst()); 1525f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1526f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower ++count; 1527f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1528f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1529f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return count; 1530f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1531f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1532c51a29a43a7c2b4a7bab39282b01b46bfceac498Asim ShankarTF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, 1533c51a29a43a7c2b4a7bab39282b01b46bfceac498Asim Shankar const char* attr_name, 1534c51a29a43a7c2b4a7bab39282b01b46bfceac498Asim Shankar TF_Status* status) { 1535c51a29a43a7c2b4a7bab39282b01b46bfceac498Asim Shankar TF_AttrMetadata metadata; 153695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 153795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return metadata; 153895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar switch (attr->value_case()) { 153995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#define SINGLE_CASE(kK, attr_type, size_expr) \ 154095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar case tensorflow::AttrValue::kK: \ 154195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.is_list = 0; \ 154295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.list_size = -1; \ 154395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = attr_type; \ 154495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.total_size = size_expr; \ 154595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; 154695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 154795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); 154895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kI, TF_ATTR_INT, -1); 154995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); 155095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kB, TF_ATTR_BOOL, -1); 155195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kType, TF_ATTR_TYPE, -1); 155295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kShape, TF_ATTR_SHAPE, 155395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); 155495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); 155595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#undef SINGLE_CASE 155695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 155795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar case tensorflow::AttrValue::kList: 155895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.is_list = 1; 155995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.list_size = 0; 156095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.total_size = -1; 156195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#define LIST_CASE(field, attr_type, ...) \ 156295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr->list().field##_size() > 0) { \ 156395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = attr_type; \ 156495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.list_size = attr->list().field##_size(); \ 156595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar __VA_ARGS__; \ 156695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; \ 156795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 156895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 156995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; 157095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < attr->list().s_size(); 157195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar ++i) { metadata.total_size += attr->list().s(i).size(); }); 157295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(i, TF_ATTR_INT); 157395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(f, TF_ATTR_FLOAT); 157495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(b, TF_ATTR_BOOL); 157595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(type, TF_ATTR_TYPE); 157695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; 157795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < attr->list().shape_size(); ++i) { 157895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto& s = attr->list().shape(i); 157995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); 158095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar }); 158195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar LIST_CASE(tensor, TF_ATTR_TENSOR); 1582d4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221Peter Hawkins LIST_CASE(tensor, TF_ATTR_FUNC); 158395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#undef LIST_CASE 158495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar // All lists empty, determine the type from the OpDef. 158595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (metadata.list_size == 0) { 158695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { 158795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto& a = oper->node.op_def().attr(i); 158895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (a.name().compare(attr_name) != 0) continue; 1589ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& typestr = a.type(); 159095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (typestr == "list(string)") { 159195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_STRING; 159295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(int)") { 159395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_INT; 159495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(float)") { 159595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_FLOAT; 159695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(bool)") { 159795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_BOOL; 159895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(type)") { 159995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_TYPE; 160095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(shape)") { 160195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_SHAPE; 160295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else if (typestr == "list(tensor)") { 160395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_TENSOR; 1604d4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221Peter Hawkins } else if (typestr == "list(func)") { 1605d4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221Peter Hawkins metadata.type = TF_ATTR_FUNC; 160695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } else { 160795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = InvalidArgument( 160895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar "Attribute '", attr_name, 160995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar "' has an empty value of an unrecognized type '", typestr, "'"); 161095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return metadata; 161195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 161295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 161395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 161495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; 161595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 161695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar case tensorflow::AttrValue::kPlaceholder: 161795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.is_list = 0; 161895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.list_size = -1; 161995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_PLACEHOLDER; 162095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.total_size = -1; 162195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; 162295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 162395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar case tensorflow::AttrValue::kFunc: 162495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.is_list = 0; 162595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.list_size = -1; 162695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.type = TF_ATTR_FUNC; 162795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar metadata.total_size = -1; 162895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; 162995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 163095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar case tensorflow::AttrValue::VALUE_NOT_SET: 163195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 163295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar InvalidArgument("Attribute '", attr_name, "' has no value set"); 163395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar break; 163495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 163595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return metadata; 163695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 163795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 163895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, 1639786938758a39940ae0154834bfed9e21894afa28Asim Shankar void* value, size_t max_length, 1640786938758a39940ae0154834bfed9e21894afa28Asim Shankar TF_Status* status) { 164195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 164295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 164395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr->value_case() != tensorflow::AttrValue::kS) { 164495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 164595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar InvalidArgument("Attribute '", attr_name, "' is not a string"); 16467cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return; 16477cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 164895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (max_length <= 0) { 164995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return; 165095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 165195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto& s = attr->s(); 165295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length)); 165395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 16547cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 165595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, 1656786938758a39940ae0154834bfed9e21894afa28Asim Shankar void** values, size_t* lengths, 1657786938758a39940ae0154834bfed9e21894afa28Asim Shankar int max_values, void* storage, 1658786938758a39940ae0154834bfed9e21894afa28Asim Shankar size_t storage_size, TF_Status* status) { 165995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 166095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 166195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr->value_case() != tensorflow::AttrValue::kList) { 166295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 166395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar InvalidArgument("Value for '", attr_name, "' is not a list"); 16647cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return; 16657cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 166695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto len = std::min(max_values, attr->list().s_size()); 166795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar char* p = static_cast<char*>(storage); 166895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < len; ++i) { 1669ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& s = attr->list().s(i); 167095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar values[i] = p; 167195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar lengths[i] = s.size(); 167295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) { 167395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = InvalidArgument( 167495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar "Not enough storage to hold the requested list of strings"); 167595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return; 167695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 167795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar memcpy(values[i], s.data(), s.size()); 167895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar p += s.size(); 167995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 168095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 16817cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 168273882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ 168373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving void func(TF_Operation* oper, const char* attr_name, c_type* value, \ 168473882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving TF_Status* status) { \ 168573882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving cpp_type v; \ 168673882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = \ 168773882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ 168873882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving *value = static_cast<c_type>(v); \ 168973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving } \ 169073882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ 169173882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving int max_values, TF_Status* status) { \ 169273882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving const auto* attr = GetAttrValue(oper, attr_name, status); \ 169373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving if (!status->status.ok()) return; \ 169473882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving if (attr->value_case() != tensorflow::AttrValue::kList) { \ 169573882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = \ 169673882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving InvalidArgument("Value for '", attr_name, "' is not a list."); \ 169773882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving return; \ 169873882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving } \ 169973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving const auto len = std::min(max_values, attr->list().list_field##_size()); \ 170073882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving for (int i = 0; i < len; ++i) { \ 170173882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving values[i] = static_cast<c_type>(attr->list().list_field(i)); \ 170273882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving } \ 170395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 170495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim ShankarDEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); 170595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim ShankarDEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); 170695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim ShankarDEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); 170795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim ShankarDEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); 170895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar#undef DEFINE_GETATTR 170995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 171095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, 171195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int64_t* value, int num_dims, TF_Status* status) { 171295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar PartialTensorShape shape; 171373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = 171473882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); 171595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 171695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar auto len = std::min(shape.dims(), num_dims); 171795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < len; ++i) { 171895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar value[i] = shape.dim_size(i); 171995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 17207cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower} 17217cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 172295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, 172395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int64_t** values, int* num_dims, 172495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int max_values, int64_t* storage, 172595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int storage_size, TF_Status* status) { 172695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar std::vector<PartialTensorShape> shapes; 172795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 172873882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); 172995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 173095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar auto len = std::min(static_cast<int>(shapes.size()), max_values); 173195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int64_t* p = storage; 173295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int storage_left = storage_size; 173395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < len; ++i) { 173495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar // shapes[i].dims() == -1 for shapes with an unknown rank. 173595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar int64_t n = shapes[i].dims(); 173695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar num_dims[i] = n; 173795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar values[i] = p; 173895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (n < 0) { 173995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar continue; 174095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 174195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (storage_left < n) { 174295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = InvalidArgument( 174395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar "Not enough storage to hold the requested list of shapes"); 174495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return; 174595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 174695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar storage_left -= n; 174795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int j = 0; j < n; ++j, ++p) { 174895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar *p = shapes[i].dim_size(j); 174995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 175095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 175195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 175295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 175395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, 175495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const char* attr_name, 175595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Buffer* value, TF_Status* status) { 175695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 175795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 175895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr->value_case() != tensorflow::AttrValue::kShape) { 175995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 176095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar InvalidArgument("Value for '", attr_name, "' is not a shape."); 176195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return; 176295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 176395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(attr->shape(), value); 176495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 176595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 176695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, 176795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const char* attr_name, 176895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Buffer** values, int max_values, 176995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 177095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 177195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 177295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (attr->value_case() != tensorflow::AttrValue::kList) { 177395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = 177495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar InvalidArgument("Value for '", attr_name, "' is not a list"); 1775f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return; 1776f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 177795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto len = std::min(max_values, attr->list().shape_size()); 177895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < len; ++i) { 177995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar values[i] = TF_NewBuffer(); 178095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(attr->list().shape(i), values[i]); 178195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) { 178295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar // Delete everything allocated to far, the operation has failed. 178395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int j = 0; j <= i; ++j) { 178495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_DeleteBuffer(values[j]); 178595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 178695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar return; 178795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 178895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 178995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 1790f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 179195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, 179295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Tensor** value, TF_Status* status) { 179395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar *value = nullptr; 179495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar Tensor t; 179573882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); 179695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 179796675956ef17e609d1bd60591fc998890d505004Asim Shankar *value = TF_TensorFromTensor(t, status); 179895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 179995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 180095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, 180195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Tensor** values, int max_values, 180295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 180395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar std::vector<Tensor> ts; 180473882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); 180595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 180695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto len = std::min(max_values, static_cast<int>(ts.size())); 180795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar for (int i = 0; i < len; ++i) { 180896675956ef17e609d1bd60591fc998890d505004Asim Shankar values[i] = TF_TensorFromTensor(ts[i], status); 180995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar } 181095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 181195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 181295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, 181395ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Buffer* output_attr_value, 181495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 181595ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar const auto* attr = GetAttrValue(oper, attr_name, status); 181695ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar if (!status->status.ok()) return; 181795ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(*attr, output_attr_value); 181895ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar} 181995ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar 182095ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankarvoid TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, 182195ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar TF_Status* status) { 182295ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(oper->node.def(), output_node_def); 1823f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1824f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1825f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower// TF_Graph functions --------------------------------------------------------- 1826f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1827e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey IrvingTF_Graph::TF_Graph() 1828e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving : graph(tensorflow::OpRegistry::Global()), 1829e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving refiner(graph.versions().producer(), graph.op_registry()), 1830e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving delete_requested(false), 1831e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving parent(nullptr), 1832e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving parent_inputs(nullptr) {} 1833e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving 1834f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowerTF_Graph* TF_NewGraph() { return new TF_Graph; } 1835f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1836f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowervoid TF_DeleteGraph(TF_Graph* g) { 1837f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower g->mu.lock(); 1838f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower g->delete_requested = true; 1839cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev const bool del = g->sessions.empty(); 1840f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower g->mu.unlock(); 1841f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (del) delete g; 1842f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1843f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1844a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerTF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { 1845f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock l(graph->mu); 1846a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower auto iter = graph->name_map.find(oper_name); 1847f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (iter == graph->name_map.end()) { 1848f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return nullptr; 1849f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 1850a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower return ToOperation(iter->second); 1851f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1852f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1853f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1854a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlowerTF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { 1855f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (*pos == 0) { 1856edaf3b342db4afa1c872da541fb0ac176a4e8ef9A. Unique TensorFlower // Advance past the first sentinel nodes in every graph (the source & sink). 1857f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower *pos += 2; 1858f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 1859f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // Advance to the next node. 1860f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower *pos += 1; 1861f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1862f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1863f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock l(graph->mu); 186454e5000e0b980abe905900599c4493fadae34a15A. Unique TensorFlower while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) { 1865f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower Node* node = graph->graph.FindNodeId(*pos); 1866f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // FindNodeId() returns nullptr for nodes that have been deleted. 1867f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // We aren't currently allowing nodes to be deleted, but it is safer 1868f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // to still check. 1869a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower if (node != nullptr) return ToOperation(node); 1870f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower *pos += 1; 1871f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 1872f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1873f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // No more nodes. 1874f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower return nullptr; 1875f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1876f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1877f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlowervoid TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, 1878f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower TF_Status* status) { 1879f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower GraphDef def; 1880f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower { 1881f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock l(graph->mu); 1882f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower graph->graph.ToGraphDef(&def); 1883f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 188495ddc4ead6c6713ee4fcabd01ec107c3b2ca906eAsim Shankar status->status = MessageToBuffer(def, output_graph_def); 1885f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 1886f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 1887f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milnevoid TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, 1888f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne TF_Buffer* output_op_def, TF_Status* status) { 1889f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne const OpDef* op_def; 1890f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne { 1891f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne mutex_lock l(graph->mu); 1892f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); 1893f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne if (!status->status.ok()) return; 1894f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne } 1895f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne status->status = MessageToBuffer(*op_def, output_op_def); 1896f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne} 1897f5d3bf42b892ecfbde2ce9eb45f00b76473c824aSkye Wanderman-Milne 18987fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milnevoid TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, 18997fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne TF_Status* status) { 19007fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne VersionDef versions; 19017fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne { 19027fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne mutex_lock l(graph->mu); 19037fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne versions = graph->graph.versions(); 19047fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne } 19057fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne status->status = MessageToBuffer(versions, output_version_def); 19067fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne} 19077fd261602677d3c251fba05264a20318231deb76Skye Wanderman-Milne 190822221698a3ecd43024f84cd6c468ddd00955f920Asim ShankarTF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { 190922221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar return new TF_ImportGraphDefOptions; 191022221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar} 191122221698a3ecd43024f84cd6c468ddd00955f920Asim Shankarvoid TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { 191222221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar delete opts; 191322221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar} 191422221698a3ecd43024f84cd6c468ddd00955f920Asim Shankarvoid TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, 191522221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar const char* prefix) { 191622221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar opts->opts.prefix = prefix; 191722221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar} 191822221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar 1919c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milnevoid TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, 1920c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne unsigned char uniquify_names) { 1921c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne opts->opts.uniquify_names = uniquify_names; 1922c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne} 1923c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne 1924c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milnevoid TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, 1925c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne unsigned char uniquify_prefix) { 1926c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne opts->opts.uniquify_prefix = uniquify_prefix; 1927c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne} 1928c7778898eaf001c82744a8f4c71eb9a880a158f0Skye Wanderman-Milne 1929bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milnevoid TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, 1930bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne const char* src_name, 1931bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne int src_index, TF_Output dst) { 1932ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne opts->tensor_id_data.push_back(src_name); 1933ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& src_name_str = opts->tensor_id_data.back(); 1934ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne // We don't need to store dst's name in tensor_id_data, since `dst` must 1935ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne // outlive the ImportGraphDef call. 1936ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); 1937bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne} 1938bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne 19390386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Manévoid TF_ImportGraphDefOptionsRemapControlDependency( 19400386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { 19410386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = 19420386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); 19430386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané} 19440386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané 1945bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milneextern void TF_ImportGraphDefOptionsAddControlDependency( 1946bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne TF_ImportGraphDefOptions* opts, TF_Operation* oper) { 1947bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne opts->opts.control_dependencies.push_back(oper->node.name()); 1948bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne} 1949bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne 1950bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milnevoid TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, 1951bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne const char* oper_name, int index) { 1952ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne opts->tensor_id_data.push_back(oper_name); 1953ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& oper_name_str = opts->tensor_id_data.back(); 1954ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne opts->opts.return_tensors.emplace_back(oper_name_str, index); 1955bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne} 1956bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne 1957bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milneint TF_ImportGraphDefOptionsNumReturnOutputs( 1958bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne const TF_ImportGraphDefOptions* opts) { 1959bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne return opts->opts.return_tensors.size(); 1960bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne} 1961bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne 1962ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnevoid TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, 1963ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const char* oper_name) { 1964ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne opts->opts.return_nodes.push_back(oper_name); 1965ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1966ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 1967ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milneint TF_ImportGraphDefOptionsNumReturnOperations( 1968ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const TF_ImportGraphDefOptions* opts) { 1969ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return opts->opts.return_nodes.size(); 1970ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1971ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 1972ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnevoid TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, 1973ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne int* num_outputs, 1974ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Output** outputs) { 1975ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne *num_outputs = results->return_tensors.size(); 1976ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne *outputs = results->return_tensors.data(); 1977ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1978ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 1979ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnevoid TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, 1980ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne int* num_opers, 1981ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Operation*** opers) { 1982ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne *num_opers = results->return_nodes.size(); 1983ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne *opers = results->return_nodes.data(); 1984ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1985ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 1986968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milnevoid TF_ImportGraphDefResultsMissingUnusedInputMappings( 1987968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, 1988ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const char*** src_names, int** src_indexes) { 1989968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); 1990968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne *src_names = results->missing_unused_key_names.data(); 1991968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne *src_indexes = results->missing_unused_key_indexes.data(); 1992ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1993ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 1994ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milnevoid TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { 1995ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne delete results; 1996ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 1997ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 199835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseustatic void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, 199935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu const TF_ImportGraphDefOptions* opts, 2000ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_ImportGraphDefResults* tf_results, 2001ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Status* status) 200235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { 200335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu const int last_node_id = graph->graph.num_node_ids(); 2004dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne tensorflow::ImportGraphDefResults results; 2005dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, 2006dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne &graph->refiner, &results); 200735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (!status->status.ok()) return; 2008ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 2009ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne // Add new nodes to name_map 201035ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { 201135ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu auto* node = graph->graph.FindNodeId(i); 201235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (node != nullptr) graph->name_map[node->name()] = node; 201335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu } 2014ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 2015ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne // Populate return_tensors 2016ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne DCHECK(tf_results->return_tensors.empty()); 2017ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tf_results->return_tensors.resize(results.return_tensors.size()); 2018ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne for (int i = 0; i < results.return_tensors.size(); ++i) { 2019ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tf_results->return_tensors[i].oper = 2020ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne ToOperation(results.return_tensors[i].first); 2021ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tf_results->return_tensors[i].index = results.return_tensors[i].second; 2022ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 2023ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 2024ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne // Populate return_nodes 2025ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne DCHECK(tf_results->return_nodes.empty()); 2026ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tf_results->return_nodes.resize(results.return_nodes.size()); 2027ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne for (int i = 0; i < results.return_nodes.size(); ++i) { 2028ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); 2029ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 2030ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 2031968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne // Populate missing unused map keys 2032968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne DCHECK(tf_results->missing_unused_key_names.empty()); 2033968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne DCHECK(tf_results->missing_unused_key_indexes.empty()); 2034968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne DCHECK(tf_results->missing_unused_key_names_data.empty()); 2035968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne 2036968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne size_t size = results.missing_unused_input_map_keys.size(); 2037968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_names.resize(size); 2038968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_indexes.resize(size); 2039968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne 2040968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne for (int i = 0; i < size; ++i) { 2041968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne TensorId id = results.missing_unused_input_map_keys[i]; 2042968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); 2043968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_names[i] = 2044968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_names_data.back().c_str(); 2045968da4bf2722b1303cc223e8342357d62c27dfc1Skye Wanderman-Milne tf_results->missing_unused_key_indexes[i] = id.second; 2046ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 2047ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne} 2048ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne 2049ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-MilneTF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( 2050ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_Graph* graph, const TF_Buffer* graph_def, 2051ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const TF_ImportGraphDefOptions* options, TF_Status* status) { 2052ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne GraphDef def; 2053ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne if (!def.ParseFromArray(graph_def->data, graph_def->length)) { 2054ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne status->status = InvalidArgument("Invalid GraphDef"); 2055ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return nullptr; 2056bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne } 2057ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne auto results = new TF_ImportGraphDefResults(); 2058ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne mutex_lock l(graph->mu); 2059ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne GraphImportGraphDefLocked(graph, def, options, results, status); 2060ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne if (!status->status.ok()) { 2061ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne delete results; 2062ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return nullptr; 2063ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 2064ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return results; 206535ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu} 206635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 2067bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milnevoid TF_GraphImportGraphDefWithReturnOutputs( 2068bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne TF_Graph* graph, const TF_Buffer* graph_def, 2069ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, 2070bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne int num_return_outputs, TF_Status* status) { 2071ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne if (num_return_outputs != options->opts.return_tensors.size()) { 2072ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne status->status = InvalidArgument("Expected 'num_return_outputs' to be ", 2073ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne options->opts.return_tensors.size(), 2074ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne ", got ", num_return_outputs); 2075ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return; 2076ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 2077ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne if (num_return_outputs > 0 && return_outputs == nullptr) { 2078ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne status->status = InvalidArgument( 2079ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne "'return_outputs' must be preallocated to length ", num_return_outputs); 2080ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne return; 2081ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne } 208222221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar GraphDef def; 208322221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar if (!def.ParseFromArray(graph_def->data, graph_def->length)) { 208422221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar status->status = InvalidArgument("Invalid GraphDef"); 208522221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar return; 208622221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar } 2087ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_ImportGraphDefResults results; 208822221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar mutex_lock l(graph->mu); 2089ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne GraphImportGraphDefLocked(graph, def, options, &results, status); 2090ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne DCHECK_EQ(results.return_tensors.size(), num_return_outputs); 2091ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne memcpy(return_outputs, results.return_tensors.data(), 2092ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne num_return_outputs * sizeof(TF_Output)); 2093bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne} 2094bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne 2095bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milnevoid TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, 2096bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne const TF_ImportGraphDefOptions* options, 2097bb71ec089658fb8a91423a7cf7195e5c900c2c98Skye Wanderman-Milne TF_Status* status) { 2098ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_ImportGraphDefResults* results = 2099ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); 2100ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_DeleteImportGraphDefResults(results); 210122221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar} 210222221698a3ecd43024f84cd6c468ddd00955f920Asim Shankar 2103661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne// While loop functions ------------------------------------------------------- 2104661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2105661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnenamespace { 21060fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 21070fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#ifndef __ANDROID__ 21080fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 21090fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Creates a placeholder representing an input to the cond or body graph. 21100fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// TODO(skyewm): remove these from final graph 2111661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnebool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, 2112661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* input, TF_Status* status) { 2113661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name); 2114661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input)); 2115661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne // TODO(skyewm): set placeholder shape 2116661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Operation* oper = TF_FinishOperation(desc, status); 2117661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!status->status.ok()) return false; 2118661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne *input = {oper, 0}; 2119661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return true; 2120661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2121661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2122661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input 21230fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix` 21240fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// will be prepended to copied node names. `control_deps` are nodes in 21250fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// `dst_graph` that the copied `src_graph` nodes will have control dependencies 21260fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes 21270fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// in `dst_graph` will be returned. `return_nodes` must be non-null. 21280fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-MilneStatus CopyGraph(Graph* src_graph, Graph* dst_graph, 21290fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::ShapeRefiner* dst_refiner, 21300fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const TF_Output* src_inputs, 21310fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const std::vector<tensorflow::Output>& dst_inputs, 2132ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne const string& prefix, 21330fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const std::vector<tensorflow::Operation>& control_deps, 21340fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const TF_Output* nodes_to_return, int nreturn_nodes, 21350fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output>* return_nodes) { 21360fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne DCHECK(return_nodes != nullptr); 2137661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne GraphDef gdef; 21380fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne src_graph->ToGraphDef(&gdef); 2139661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 21400fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::ImportGraphDefOptions opts; 21410fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne opts.prefix = prefix; 2142661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2143661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne for (int i = 0; i < dst_inputs.size(); ++i) { 21440fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne opts.input_map[ToTensorId(src_inputs[i])] = 21450fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index()); 2146661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 21470fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne opts.skip_mapped_nodes = true; 2148e1030858725b485b0f848cc27597b8e2c2d8383fIgor Ganichev 21490fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (const tensorflow::Operation& op : control_deps) { 21500fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne opts.control_dependencies.push_back(op.node()->name()); 21510fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne } 2152661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2153661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne for (int i = 0; i < nreturn_nodes; ++i) { 21540fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); 2155661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2156661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2157caced55cbc205a9423a480cae0bb9e7a9a10f3a1Asim Shankar // TODO(skyewm): change to OutputTensor 2158dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne tensorflow::ImportGraphDefResults results; 21590fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne TF_RETURN_IF_ERROR( 2160dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); 21610fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 2162dc442f4ce2d3b11b56721337fe2b9e2282be93beSkye Wanderman-Milne for (const auto& pair : results.return_tensors) { 21630fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return_nodes->emplace_back(pair.first, pair.second); 21640fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne } 21650fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return Status::OK(); 2166661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2167661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2168661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnebool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) { 2169661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (params.cond_graph == nullptr || params.body_graph == nullptr || 2170661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne params.cond_graph->parent == nullptr || 2171661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne params.cond_graph->parent != params.body_graph->parent || 2172661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne params.cond_graph->parent_inputs != params.body_graph->parent_inputs || 2173661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne params.ninputs <= 0 || params.cond_inputs == nullptr || 2174661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne params.body_inputs == nullptr || params.body_outputs == nullptr) { 2175661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne s->status = InvalidArgument( 2176661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne "TF_WhileParams must be created by successful TF_NewWhile() call"); 2177661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return false; 2178661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2179661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return true; 2180661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2181661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2182661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnebool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) { 2183661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (params.cond_output.oper == nullptr) { 2184661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set"); 2185661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return false; 2186661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2187661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne for (int i = 0; i < params.ninputs; ++i) { 2188661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (params.body_outputs[i].oper == nullptr) { 2189661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ", 2190661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne "field isn't set"); 2191661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return false; 2192661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2193661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2194661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (params.name == nullptr) { 2195661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne s->status = InvalidArgument("TF_WhileParams `name` field is null"); 2196661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return false; 2197661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2198661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return true; 2199661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2200661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 22010fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#endif // __ANDROID__ 22020fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 2203661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnevoid FreeWhileResources(const TF_WhileParams* params) { 2204661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_DeleteGraph(params->cond_graph); 2205661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_DeleteGraph(params->body_graph); 2206661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne delete[] params->cond_inputs; 2207661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne delete[] params->body_inputs; 2208661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne delete[] params->body_outputs; 2209661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2210661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2211661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-MilneTF_WhileParams EmptyWhileParams() { 2212661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return {0, nullptr, nullptr, {nullptr, 0}, 2213661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne nullptr, nullptr, nullptr, nullptr}; 2214661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2215661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2216661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} // namespace 2217661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2218661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-MilneTF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, 2219661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Status* status) { 22200fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#ifdef __ANDROID__ 22210fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne status->status = tensorflow::errors::Unimplemented( 22220fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "Creating while loops is not supported in Android. File a bug at " 22230fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "https://github.com/tensorflow/tensorflow/issues if this feature is " 22240fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "important to you"); 22250fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return EmptyWhileParams(); 22260fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#else 2227661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (ninputs == 0) { 2228661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne status->status = 2229661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne InvalidArgument("TF_NewWhile() must be passed at least one input"); 2230661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return EmptyWhileParams(); 2231661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2232661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2233661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Graph* cond_graph = TF_NewGraph(); 2234661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Graph* body_graph = TF_NewGraph(); 2235661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne cond_graph->parent = g; 2236661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne cond_graph->parent_inputs = inputs; 2237661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne body_graph->parent = g; 2238661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne body_graph->parent_inputs = inputs; 2239661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2240661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* cond_inputs = new TF_Output[ninputs]; 2241661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output cond_output = {nullptr, -1}; 2242661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* body_inputs = new TF_Output[ninputs]; 2243661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* body_outputs = new TF_Output[ninputs]; 2244661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1}; 2245661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne const char* name = nullptr; 2246661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2247661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne for (int i = 0; i < ninputs; ++i) { 2248661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne // TODO(skyewm): prefix names with underscore (requires some plumbing) 2249661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(), 2250661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne &cond_inputs[i], status)) { 2251661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne break; 2252661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2253661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(), 2254661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne &body_inputs[i], status)) { 2255661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne break; 2256661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2257661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2258661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2259661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output, 2260661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne body_graph, body_inputs, body_outputs, name}; 2261661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2262661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!status->status.ok()) { 2263661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne FreeWhileResources(¶ms); 2264661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return EmptyWhileParams(); 2265661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2266661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne return params; 22670fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#endif // __ANDROID__ 2268661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2269661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 22700fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#ifndef __ANDROID__ 2271661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnenamespace { 2272661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2273661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne// TODO(skyewm): make nodes in while loop unfetchable like in Python version 2274661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnevoid TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status, 2275661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* outputs) { 2276661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!ValidateInputWhileParams(*params, status)) return; 2277661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2278661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Graph* parent = params->cond_graph->parent; 2279661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* parent_inputs = params->cond_graph->parent_inputs; 22800fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne int num_loop_vars = params->ninputs; 2281661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2282661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne mutex_lock l(parent->mu); 2283661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 22840fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // 'cond_fn' copies the cond graph into the parent graph. 22850fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::ops::CondGraphBuilderFn cond_fn = 22860fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne [params, parent](const tensorflow::Scope& scope, 22870fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const std::vector<tensorflow::Output>& inputs, 22880fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::Output* output) { 22890fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne DCHECK_EQ(scope.graph(), &parent->graph); 22900fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output> cond_output; 22910fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne TF_RETURN_IF_ERROR(CopyGraph( 22920fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne ¶ms->cond_graph->graph, &parent->graph, &parent->refiner, 22930fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne params->cond_inputs, inputs, scope.impl()->name(), 22940fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne scope.impl()->control_deps(), ¶ms->cond_output, 22950fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne /* nreturn_nodes */ 1, &cond_output)); 22960fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne *output = cond_output[0]; 22970fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return Status::OK(); 22980fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne }; 22990fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23000fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // 'body_fn' copies the body graph into the parent graph. 23010fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::ops::BodyGraphBuilderFn body_fn = 23020fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne [params, parent, num_loop_vars]( 23030fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const tensorflow::Scope& scope, 23040fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const std::vector<tensorflow::Output>& inputs, 23050fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output>* outputs) { 23060fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne DCHECK_EQ(scope.graph(), &parent->graph); 23070fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne TF_RETURN_IF_ERROR( 23080fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne CopyGraph(¶ms->body_graph->graph, &parent->graph, 23090fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne &parent->refiner, params->body_inputs, inputs, 23100fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne scope.impl()->name(), scope.impl()->control_deps(), 23110fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne params->body_outputs, num_loop_vars, outputs)); 23120fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne return Status::OK(); 23130fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne }; 23140fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23150fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // Create the while loop using an internal scope. 23160fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::Scope scope = 23170fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne NewInternalScope(&parent->graph, &status->status, &parent->refiner) 23180fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne .NewSubScope(params->name); 23190fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23200fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const int first_new_node_id = parent->graph.num_node_ids(); 23210fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23220fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne tensorflow::OutputList loop_outputs; 23230fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne status->status = tensorflow::ops::BuildWhileLoop( 23240fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn, 23250fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne body_fn, params->name, &loop_outputs); 23260fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23270fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // Update name_map with newly-created ops. 23280fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns 23290fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // a bad status. Once we fix this, we may want to return early instead of 23300fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // executing the following code. 23310fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) { 23320fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne Node* new_node = parent->graph.FindNodeId(i); 23330fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne if (new_node == nullptr) continue; 23340fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne parent->name_map[new_node->name()] = new_node; 23350fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne } 23360fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 23370fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne // Populate 'outputs'. 23380fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne DCHECK_LE(loop_outputs.size(), num_loop_vars); 23390fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (int i = 0; i < loop_outputs.size(); ++i) { 23400fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()}; 2341661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne } 2342661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2343661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2344661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} // namespace 23450fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#endif // __ANDROID__ 2346661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2347661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnevoid TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, 2348661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_Output* outputs) { 23490fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#ifdef __ANDROID__ 23500fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne status->status = tensorflow::errors::Unimplemented( 23510fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "Creating while loops is not supported in Android. File a bug at " 23520fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "https://github.com/tensorflow/tensorflow/issues if this feature is " 23530fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne "important to you"); 23540fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#else 2355661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne // If it appears the caller created or modified `params`, don't free resources 2356661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne if (!ValidateConstWhileParams(*params, status)) return; 2357661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne TF_FinishWhileHelper(params, status, outputs); 2358661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne FreeWhileResources(params); 23590fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#endif // __ANDROID__ 2360661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne} 2361661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2362661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milnevoid TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } 2363661058e52460b36417573e2c5a73de9a8b9e5edbSkye Wanderman-Milne 2364908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumarvoid TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, 2365908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar TF_Output* dx, TF_Status* status, TF_Output* dy) { 2366908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#ifdef __ANDROID__ 2367908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar status->status = tensorflow::errors::Unimplemented( 2368908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar "Adding gradients is not supported in Android. File a bug at " 2369908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar "https://github.com/tensorflow/tensorflow/issues if this feature is " 2370908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar "important to you"); 2371908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#else 23720fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny); 23730fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx); 2374908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar std::vector<tensorflow::Output> dy_arg; 2375908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2376908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar { 2377908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar // We need to hold on to the lock while we have a scope that uses TF_Graph. 2378908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar mutex_lock graph_lock(g->mu); 2379908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 23800fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const int first_new_node_id = g->graph.num_node_ids(); 2381908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2382908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar tensorflow::Scope scope = 2383c311449c76433be7ead189c0acc7202d2e884513Skye Wanderman-Milne NewInternalScope(&g->graph, &status->status, &g->refiner) 2384c311449c76433be7ead189c0acc7202d2e884513Skye Wanderman-Milne .NewSubScope("gradients"); 2385908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2386908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar if (dx != nullptr) { 23870fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny); 2388908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar status->status = 2389908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg); 2390908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar } else { 2391908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg); 2392908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar } 2393908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2394908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar // Update g->name_map with the name_map from the scope, which will contain 2395908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar // the new gradient ops. 23960fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { 2397908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar Node* n = g->graph.FindNodeId(i); 2398908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar if (n == nullptr) continue; 2399908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar g->name_map[n->name()] = n; 2400908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar } 2401908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar } 2402908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2403908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar // Unpack the results from grad_outputs_arg. 2404908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar TFOutputsFromOutputs(dy_arg, dy); 2405908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar#endif // __ANDROID__ 2406908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar} 2407908d5b6ede6ae829dff138a873eec397ef434cd6Suharsh Sivakumar 2408e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankar// TF_Session functions ---------------------------------------------- 2409f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 241022651083406ca01ac9d481e3367a3510d25f88cdAsim ShankarTF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) 24112bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) { 241222651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar if (s->LocalDeviceManager(&device_mgr).ok()) { 241322651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar devices = device_mgr->ListDevices(); 241422651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar } 241522651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar} 241622651083406ca01ac9d481e3367a3510d25f88cdAsim Shankar 2417e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim ShankarTF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, 2418e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankar TF_Status* status) { 2419f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower Session* session; 2420f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = NewSession(opt->options, &session); 2421f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (status->status.ok()) { 2422cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev TF_Session* new_session = new TF_Session(session, graph); 2423f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (graph != nullptr) { 2424f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock l(graph->mu); 2425cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev graph->sessions[new_session] = Status::OK(); 2426f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 2427cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev return new_session; 2428f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 2429f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower DCHECK_EQ(nullptr, session); 2430d83074847ebfe8871188f1f9f1e84ab0451f59e6A. Unique TensorFlower return nullptr; 2431f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 2432f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 2433f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 243435ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan HseuTF_Session* TF_LoadSessionFromSavedModel( 243535ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu const TF_SessionOptions* session_options, const TF_Buffer* run_options, 243635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu const char* export_dir, const char* const* tags, int tags_len, 243735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) { 2438d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that 2439d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar// the tensorflow/cc/saved_model:loader build target is Android friendly. 2440d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar#ifdef __ANDROID__ 2441d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar status->status = tensorflow::errors::Unimplemented( 2442d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar "Loading a SavedModel is not supported in Android. File a bug at " 2443d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar "https://github.com/tensorflow/tensorflow/issues if this feature is " 2444d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar "important to you"); 2445d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar return nullptr; 2446d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar#else 244735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu mutex_lock l(graph->mu); 244835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (!graph->name_map.empty()) { 244935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu status->status = InvalidArgument("Graph is non-empty."); 245035ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu return nullptr; 245135ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu } 245235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 245335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu RunOptions run_options_proto; 2454d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar if (run_options != nullptr && !run_options_proto.ParseFromArray( 2455d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar run_options->data, run_options->length)) { 245635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu status->status = InvalidArgument("Unparseable RunOptions proto"); 245735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu return nullptr; 245835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu } 245935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 2460ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::unordered_set<string> tag_set; 246135ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu for (int i = 0; i < tags_len; i++) { 2462ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne tag_set.insert(string(tags[i])); 246335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu } 246435ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 246535ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu tensorflow::SavedModelBundle bundle; 246635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu status->status = 246735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu tensorflow::LoadSavedModel(session_options->options, run_options_proto, 246835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu export_dir, tag_set, &bundle); 246935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (!status->status.ok()) return nullptr; 247035ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 247135ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session 247235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // extends using GraphDefs. The Graph instance is different, but equivalent 247335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // to the one used to create the session. 247435ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // 247535ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // TODO(jhseu): When Session is modified to take Graphs instead of 247635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu // GraphDefs, return the Graph generated in LoadSavedModel(). 247735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); 2478ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne TF_ImportGraphDefResults results; 247935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), 2480ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne import_opts, &results, status); 248135ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu TF_DeleteImportGraphDefOptions(import_opts); 248235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (TF_GetCode(status) != TF_OK) return nullptr; 248335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 248435ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (meta_graph_def != nullptr) { 248535ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); 248635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu if (!status->status.ok()) return nullptr; 248735ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu } 248835ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 248935ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu TF_Session* session = new TF_Session(bundle.session.release(), graph); 249035ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 2491cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev graph->sessions[session] = Status::OK(); 249235ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu session->last_num_graph_nodes = graph->graph.num_node_ids(); 249335ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu return session; 2494e8f2aad0c0502fde74fc629f5b13f04d5d206700Asim Shankar#endif // __ANDROID__ 2495d18f4c54cd5e6a009c4ab1bc01c3a5432bafa6aeAsim Shankar} 249635ea93ba205d44319f408a7d90dc385fd86d6b56Jonathan Hseu 2497e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankarvoid TF_CloseSession(TF_Session* s, TF_Status* status) { 2498f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = s->session->Close(); 2499f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 2500f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 2501e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankarvoid TF_DeleteSession(TF_Session* s, TF_Status* status) { 25022bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar status->status = Status::OK(); 25032bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar TF_Graph* const graph = s->graph; 25042bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar if (graph != nullptr) { 25052bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar graph->mu.lock(); 2506cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev graph->sessions.erase(s); 2507cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev const bool del = graph->delete_requested && graph->sessions.empty(); 25082bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar graph->mu.unlock(); 25092bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar if (del) delete graph; 2510f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 25112bae08e2afe62afbf83064ae7d9e5d2aa2ef9ee6Asim Shankar delete s->session; 2512f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower delete s; 2513f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 2514f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 25157cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower// TODO(josh11b,mrry): Change Session to be able to use a Graph* 25167cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower// directly, instead of requiring us to serialize to a GraphDef and 25177cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower// call Session::Extend(). 2518e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankarstatic bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { 2519f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (session->graph != nullptr) { 2520f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower mutex_lock session_lock(session->mu); 2521f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower session->graph->mu.lock(); 2522f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const Graph& graph = session->graph->graph; 2523cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev 2524cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev status->status = session->graph->sessions[session]; 2525cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev if (!status->status.ok()) { 2526cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev session->graph->mu.unlock(); 2527cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev return false; 2528cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev } 2529cb5a63d8d2b6e049a0a128ba47560f842497db8bIgor Ganichev 2530f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower const auto num_nodes = graph.num_node_ids(); 2531f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (session->last_num_graph_nodes < num_nodes) { 2532f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist status->status = tensorflow::ValidateNoCycles(session->graph->graph); 2533f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist if (!status->status.ok()) { 2534f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist session->graph->mu.unlock(); 2535f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist return false; 2536f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist } 2537f8d4b5eae59124c9a4b01a4cc1097e1f0006137aOlivia Nordquist 2538f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower GraphDef graph_def; 25399501c4104125fb8c2c2d2e837fc2dd8a24034d52A. Unique TensorFlower *graph_def.mutable_versions() = graph.versions(); 2540f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // Fill graph_def with nodes with ids in the range 2541f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // [session->last_num_graph_nodes, num_nodes), that is the nodes 2542f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // added since the last TF_SessionRun() call. 2543f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { 2544f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower Node* const node = graph.FindNodeId(id); 2545f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (node != nullptr && node->IsOp()) { 2546f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower NodeDef* const node_def = graph_def.add_node(); 2547f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower *node_def = node->def(); 2548f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 2549f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 255092f403f05d776873868f7e5c7b52d559ba7f0efaA. Unique TensorFlower *graph_def.mutable_library() = graph.flib_def().ToProto(); 2551f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower session->graph->mu.unlock(); 2552f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower status->status = session->session->Extend(graph_def); 2553f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower if (!status->status.ok()) { 2554f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // Contract is we always delete input_values[i]. 25557cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return false; 2556f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 2557f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // Note: session->session is not modified if Extend() fails, so 2558f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower // we only set last_num_graph_nodes if it succeeds. 2559f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower session->last_num_graph_nodes = num_nodes; 2560f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } else { 2561f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower session->graph->mu.unlock(); 2562f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 2563f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 25647cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return true; 25657cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower} 25667cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2567e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankarvoid TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, 25688f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu const TF_Output* inputs, TF_Tensor* const* input_values, 25698f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu int ninputs, const TF_Output* outputs, 25707cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Tensor** output_values, int noutputs, 2571a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const TF_Operation* const* target_opers, int ntargets, 25727cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Buffer* run_metadata, TF_Status* status) { 25737cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // TODO(josh11b,mrry): Change Session to be able to use a Graph* 25747cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // directly, instead of requiring us to serialize to a GraphDef and 25757cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // call Session::Extend(). 25767cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower if (!ExtendSessionGraphHelper(session, status)) { 25777cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return; 25787cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 25797cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 25807cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Setup(noutputs, output_values, status); 25817cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 25828f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu // Convert from TF_Output and TF_Tensor to a string and Tensor. 2583ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<std::pair<string, Tensor>> input_pairs(ninputs); 25841f0c5119a0230c5160d45496175b9256f097e144Asim Shankar if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; 25857cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 25868f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu input_pairs[i].first = OutputName(inputs[i]); 25877cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 25887cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 25898f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu // Convert from TF_Output to string names. 2590ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 25917cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 25928f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu output_names[i] = OutputName(outputs[i]); 25937cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 25947cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2595a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower // Convert from TF_Operation* to string names. 2596ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_names(ntargets); 25977cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 2598a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_names[i] = target_opers[i]->node.name(); 25997cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26007cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26017cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // Actually run. 26027cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Helper(session->session, nullptr, run_options, input_pairs, 26037cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower output_names, output_values, target_names, run_metadata, 26047cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower status); 26057cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower} 26067cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26078f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseuvoid TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, 26088f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu int ninputs, const TF_Output* outputs, int noutputs, 2609a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const TF_Operation* const* target_opers, int ntargets, 26107cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower const char** handle, TF_Status* status) { 2611348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower *handle = nullptr; 2612348812b4a56ca460d2006abf1032f9ad9c86a084A. Unique TensorFlower 26137cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower if (!ExtendSessionGraphHelper(session, status)) { 26147cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return; 26157cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26167cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2617ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> input_names(ninputs); 26187cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 26198f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu input_names[i] = OutputName(inputs[i]); 26207cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26217cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2622ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 26237cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 26248f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu output_names[i] = OutputName(outputs[i]); 26257cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26267cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2627ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_names(ntargets); 26287cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 2629a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_names[i] = target_opers[i]->node.name(); 26307cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26317cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2632ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne string new_handle; 26337cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower status->status = session->session->PRunSetup(input_names, output_names, 26347cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower target_names, &new_handle); 26357cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower if (status->status.ok()) { 26367cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower char* buf = new char[new_handle.size() + 1]; 26377cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower memcpy(buf, new_handle.c_str(), new_handle.size() + 1); 26387cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower *handle = buf; 26397cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26407cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower} 26417cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26420c5d9bef6e36fbe8f57e1e9faa0fede411cbdc4fAsim Shankarvoid TF_DeletePRunHandle(const char* handle) { 26430c5d9bef6e36fbe8f57e1e9faa0fede411cbdc4fAsim Shankar delete[] handle; 26440c5d9bef6e36fbe8f57e1e9faa0fede411cbdc4fAsim Shankar // TODO(suharshs): Free up any resources held by the partial run state. 26450c5d9bef6e36fbe8f57e1e9faa0fede411cbdc4fAsim Shankar} 26460c5d9bef6e36fbe8f57e1e9faa0fede411cbdc4fAsim Shankar 2647e580e721cc1a205cb4b7afe64bc6af4d775a4851Asim Shankarvoid TF_SessionPRun(TF_Session* session, const char* handle, 26488f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu const TF_Output* inputs, TF_Tensor* const* input_values, 26498f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu int ninputs, const TF_Output* outputs, 26507cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Tensor** output_values, int noutputs, 2651a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower const TF_Operation* const* target_opers, int ntargets, 26527cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Status* status) { 26537cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // TODO(josh11b,mrry): Change Session to be able to use a Graph* 26547cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // directly, instead of requiring us to serialize to a GraphDef and 26557cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower // call Session::Extend(). 26567cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower if (!ExtendSessionGraphHelper(session, status)) { 26577cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower return; 26587cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 2659f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 26607cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Setup(noutputs, output_values, status); 26617cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26628f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu // Convert from TF_Output and TF_Tensor to a string and Tensor. 2663ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<std::pair<string, Tensor>> input_pairs(ninputs); 26641f0c5119a0230c5160d45496175b9256f097e144Asim Shankar if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; 2665f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < ninputs; ++i) { 26668f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu input_pairs[i].first = OutputName(inputs[i]); 2667f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 26687cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26698f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu // Convert from TF_Output to string names. 2670ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> output_names(noutputs); 2671f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < noutputs; ++i) { 26728f6cb22c675b5c0b553334a8f04daef462905d69Jonathan Hseu output_names[i] = OutputName(outputs[i]); 2673f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower } 26747cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 2675a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower // Convert from TF_Operation* to string names. 2676ce0238198052358d102ca7786ad9be60a5e76d28Skye Wanderman-Milne std::vector<string> target_names(ntargets); 2677f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower for (int i = 0; i < ntargets; ++i) { 2678a823862343e1b4e987bd0d279e94a126db7f2158A. Unique TensorFlower target_names[i] = target_opers[i]->node.name(); 26797cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower } 26807cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower 26817cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, 26827cc7b56b8a605f52d717173122a382cadd611793A. Unique TensorFlower output_values, target_names, nullptr, status); 2683f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} 2684f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower 2685d064a47543f51ff5a62927a76bb0fb0862d05558Anna RTF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { 2686d064a47543f51ff5a62927a76bb0fb0862d05558Anna R tensorflow::OpList op_list; 2687d064a47543f51ff5a62927a76bb0fb0862d05558Anna R if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { 2688d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = InvalidArgument("Unparseable OpList"); 2689d064a47543f51ff5a62927a76bb0fb0862d05558Anna R return nullptr; 2690d064a47543f51ff5a62927a76bb0fb0862d05558Anna R } 2691d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = Status::OK(); 2692d064a47543f51ff5a62927a76bb0fb0862d05558Anna R return new TF_ApiDefMap(op_list); 2693d064a47543f51ff5a62927a76bb0fb0862d05558Anna R} 2694d064a47543f51ff5a62927a76bb0fb0862d05558Anna R 2695d064a47543f51ff5a62927a76bb0fb0862d05558Anna Rvoid TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } 2696d064a47543f51ff5a62927a76bb0fb0862d05558Anna R 2697d064a47543f51ff5a62927a76bb0fb0862d05558Anna Rvoid TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, 2698d064a47543f51ff5a62927a76bb0fb0862d05558Anna R size_t text_len, TF_Status* status) { 2699d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#ifdef __ANDROID__ 2700d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = tensorflow::errors::Unimplemented( 2701d064a47543f51ff5a62927a76bb0fb0862d05558Anna R "ApiDefMap is not supported in Android."); 2702d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#else 2703d064a47543f51ff5a62927a76bb0fb0862d05558Anna R mutex_lock l(api_def_map->lock); 2704d064a47543f51ff5a62927a76bb0fb0862d05558Anna R if (api_def_map->update_docs_called) { 2705d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = FailedPrecondition( 2706d064a47543f51ff5a62927a76bb0fb0862d05558Anna R "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " 2707d064a47543f51ff5a62927a76bb0fb0862d05558Anna R "called."); 2708d064a47543f51ff5a62927a76bb0fb0862d05558Anna R return; 2709d064a47543f51ff5a62927a76bb0fb0862d05558Anna R } 2710d064a47543f51ff5a62927a76bb0fb0862d05558Anna R string api_def_text(text, text_len); 2711d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); 2712d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#endif // __ANDROID__ 2713d064a47543f51ff5a62927a76bb0fb0862d05558Anna R} 2714d064a47543f51ff5a62927a76bb0fb0862d05558Anna R 2715d064a47543f51ff5a62927a76bb0fb0862d05558Anna RTF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, 2716d064a47543f51ff5a62927a76bb0fb0862d05558Anna R size_t name_len, TF_Status* status) { 2717d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#ifdef __ANDROID__ 2718d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = tensorflow::errors::Unimplemented( 2719d064a47543f51ff5a62927a76bb0fb0862d05558Anna R "ApiDefMap is not supported in Android."); 2720d064a47543f51ff5a62927a76bb0fb0862d05558Anna R return nullptr; 2721d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#else 2722d064a47543f51ff5a62927a76bb0fb0862d05558Anna R mutex_lock l(api_def_map->lock); 2723d064a47543f51ff5a62927a76bb0fb0862d05558Anna R if (!api_def_map->update_docs_called) { 2724d064a47543f51ff5a62927a76bb0fb0862d05558Anna R api_def_map->api_def_map.UpdateDocs(); 2725d064a47543f51ff5a62927a76bb0fb0862d05558Anna R api_def_map->update_docs_called = true; 2726d064a47543f51ff5a62927a76bb0fb0862d05558Anna R } 2727d064a47543f51ff5a62927a76bb0fb0862d05558Anna R string name_str(name, name_len); 2728d064a47543f51ff5a62927a76bb0fb0862d05558Anna R const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); 2729d064a47543f51ff5a62927a76bb0fb0862d05558Anna R 2730d064a47543f51ff5a62927a76bb0fb0862d05558Anna R TF_Buffer* ret = TF_NewBuffer(); 2731d064a47543f51ff5a62927a76bb0fb0862d05558Anna R status->status = MessageToBuffer(*api_def, ret); 2732d064a47543f51ff5a62927a76bb0fb0862d05558Anna R return ret; 2733d064a47543f51ff5a62927a76bb0fb0862d05558Anna R#endif // __ANDROID__ 2734d064a47543f51ff5a62927a76bb0fb0862d05558Anna R} 2735f6f36020a7e9c2da563cf4053cf653ef00fbd4f7A. Unique TensorFlower} // end extern "C" 2736