1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory> 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector> 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device.h" 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device_factory.h" 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/allocator.h" 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/device_base.h" 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/graph.pb.h" 276ada43366663210beb0159b8c1a67b26ebfe6cb7Geoffrey Irving#include "tensorflow/core/framework/node_def.pb.h" 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h" 298464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower#include "tensorflow/core/framework/resource_mgr.h" 303ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h" 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_testutil.h" 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h" 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.pb.h" 343ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h" 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/status_test_util.h" 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h" 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/inlined_vector.h" 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/stl_util.h" 393ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/env.h" 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h" 41564abcc02f98ae83f8ae9969a7546b510efbbb94Josh Levenberg#include "tensorflow/core/platform/macros.h" 42af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h" 43c8eaac926c929e07ac8db69f67803a2223ff2d93Josh Levenberg#include "tensorflow/core/platform/test.h" 443ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/session_options.h" 46d9cfc64a2ddf05c0b093c8fb6704c67452ee3ea0Vijay Vasudevan#include "tensorflow/core/public/version.h" 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/util/tensor_slice_reader_cache.h" 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 49cd4f5840a834e5380536b6f04a768408c6eebf3dA. Unique TensorFlowernamespace tensorflow { 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace test { 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 526dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlowerinline void SetOutputAttrs(OpKernelContext::Params* params, 536dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower std::vector<AllocatorAttributes>* attrs) { 546dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attrs->clear(); 556dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower for (int index = 0; index < params->op_kernel->num_outputs(); index++) { 566dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower AllocatorAttributes attr; 576dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower const bool on_host = 586dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower (params->op_kernel->output_memory_types()[index] == HOST_MEMORY); 596dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attr.set_on_host(on_host); 606dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attrs->push_back(attr); 616dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower } 626dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower params->output_attr_array = gtl::vector_as_array(attrs); 636dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower} 646dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace test 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Helpful functions to test operators. 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This class will eventually be replaced / heavily modified 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to use the BrainClient interface. 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass OpsTestBase : public ::testing::Test { 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 73d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower OpsTestBase() 74d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")), 75d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower device_type_(DEVICE_CPU) { 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK(device_.get()) << "Could not create CPU device"; 77d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower allocator_ = device_->GetAllocator(AllocatorAttributes()); 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~OpsTestBase() override { 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::STLDeleteElements(&tensors_); 82d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower gtl::STLDeleteElements(&managed_outputs_); 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context_.reset(nullptr); 84ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.reset(nullptr); 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 87f6f07b027512a53e0c80fbd357ac8ef9ae7bf870A. Unique TensorFlower // Allow kernel unit tests to run on GPU 88d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device); 89f6f07b027512a53e0c80fbd357ac8ef9ae7bf870A. Unique TensorFlower 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); } 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Clients can manipulate the underlying NodeDef via this accessor. 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur NodeDef* node_def() { return &node_def_; } 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Initializes an operator that takes in 'input_types' as input 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // and output types as output. 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the status of initialization. 999e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); } 1009e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden 1019e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden // Only use this directly if you have a deprecated op that you need to test. 1029e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden Status InitOpWithGraphVersion(int graph_def_version) { 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Status status; 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), 1059e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden node_def_, graph_def_version, &status); 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (kernel_ != nullptr) input_types_ = kernel_->input_types(); 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return status; 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Adds an input for every element described by the shape. 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 'input_mapping' maps an index (0...NumElements(shape)) to a 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // value. 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // TODO(vrv): Replace with something like a BrainClient Feed. 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename T> 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) { 117d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping); 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Like AddInput but takes in an explicit arrayslice of data. 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename T> 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void AddInputFromArray(const TensorShape& shape, 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const gtl::ArraySlice<T>& data) { 124d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data); 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 12736357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner // Convenience function to add an input and populate it with the elements from 12836357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner // an initializer list converting the types as needed. 12936357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner template <typename T, typename SrcType> 13036357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner void AddInputFromList(const TensorShape& shape, 13136357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner std::initializer_list<SrcType> data) { 132d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data); 13336357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner } 13436357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner 1351ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower // Adds a Resource type as input. If <container> is empty, uses the default 1361ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower // container name. 1371ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower template <typename T> 1381ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower void AddResourceInput(const string& container, const string& name, 1391ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower T* resource) { 1401ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower CHECK_GT(input_types_.size(), inputs_.size()) 1411ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower << "Adding more inputs than types; perhaps you need to call MakeOp"; 1421ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower ResourceMgr* rm = device_->resource_manager(); 1431ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower EXPECT_TRUE( 1441ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower rm->Create(container == "" ? rm->default_container() : container, name, 1451ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower resource) 1461ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower .ok()); 1471ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower TypeIndex type_index = MakeTypeIndex<T>(); 1481ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower ResourceHandle handle; 1491ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower handle.set_device(device_->name()); 1501ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower handle.set_container(container); 1511ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower handle.set_name(name); 1521ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower handle.set_hash_code(type_index.hash_code()); 1531ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower handle.set_maybe_type_name(type_index.name()); 154d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({})); 1551ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower input->scalar<ResourceHandle>()() = handle; 1561ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower tensors_.push_back(input); 1571ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower inputs_.push_back({nullptr, input}); 1581ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower } 1591ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower 160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Runs an operation producing 'num_outputs' outputs. 161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the context's status after running the operation. 163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Status RunOpKernel() { 164ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower // Make sure the old OpKernelContext is deleted before the Params 165ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower // it was using. 166ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower context_.reset(nullptr); 167ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower 168ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.reset(new OpKernelContext::Params); 169ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->device = device_.get(); 170ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->frame_iter = FrameAndIter(0, 0); 171ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->inputs = &inputs_; 172ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->op_kernel = kernel_.get(); 1738464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower step_container_.reset(new ScopedStepContainer(0, [](const string&) {})); 1748464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower params_->step_container = step_container_.get(); 1756dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower std::vector<AllocatorAttributes> attrs; 176ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower test::SetOutputAttrs(params_.get(), &attrs); 177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; 178ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->slice_reader_cache = &slice_reader_cache_wrapper; 179cb324446acbdf0d3d2129904361cf0bcbe53e852Pete Warden params_.get()->resource_manager = device_.get()->resource_manager(); 180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 181ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower context_.reset(new OpKernelContext(params_.get())); 182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur device_->Compute(kernel_.get(), context_.get()); 183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return context_->status(); 184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the tensor input for 'input_index'. 187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES: 0 <= input_index < context_->num_inputs() 189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& GetInput(int input_index) const { 190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_LT(input_index, context_->num_inputs()); 191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK(!IsRefType(context_->input_dtype(input_index))); 192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return context_->input(input_index); 193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorValue mutable_input(int input_index) { 196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_LT(input_index, inputs_.size()); 197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return inputs_[input_index]; 198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the tensor output for 'output_index'. 200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES: 0 <= output_index < context_->num_outputs() 202d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* GetOutput(int output_index); 203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 204d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Allocator* allocator() { return allocator_; } 205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const DataTypeVector& output_types() const { return kernel_->output_types(); } 207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 208d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower private: 209d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* AddInput(DataType dtype, const TensorShape& shape) { 210d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK_GT(input_types_.size(), inputs_.size()) 211d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower << "Adding more inputs than types; perhaps you need to call MakeOp"; 212d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower bool is_ref = IsRefType(input_types_[inputs_.size()]); 213d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* input = new Tensor(allocator(), dtype, shape); 214d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower tensors_.push_back(input); 215d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower if (is_ref) { 216d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype); 217d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower inputs_.push_back({&lock_for_refs_, input}); 218d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } else { 219d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK_EQ(input_types_[inputs_.size()], dtype); 220d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower inputs_.push_back({nullptr, input}); 221d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } 222d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower return input; 223d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } 224d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected: 226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<Device> device_; 227d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower // The device allocator, or the managed_allocator_ below if running on GPU. 228d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Allocator* allocator_; 229f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 230f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<OpKernel> kernel_; 2318464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower std::unique_ptr<ScopedStepContainer> step_container_; 232f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur NodeDef node_def_; 233f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeVector input_types_; 234f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceType device_type_; 235f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 236f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur mutex lock_for_refs_; // Used as the Mutex for inputs added as refs 237f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 238f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::InlinedVector<TensorValue, 4> inputs_; 239f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Owns Tensors. 240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<Tensor*> tensors_; 241d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower // Copies of the outputs in unified memory (host and device accessible). 242d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower std::vector<Tensor*> managed_outputs_; 243f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 244ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower std::unique_ptr<OpKernelContext::Params> params_; 245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<OpKernelContext> context_; 246d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower // Unified memory allocator, only used when running on GPU. 247d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower std::unique_ptr<Allocator> managed_allocator_; 248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase); 251f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 252f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 253f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace tensorflow 254f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 255f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 256