ops_testutil.h revision 9e8541bc2a64d1e3d7b96160c445c822a06afd0a
19c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur/* Copyright 2015 Google Inc. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory> 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector> 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device.h" 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device_factory.h" 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/allocator.h" 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/device_base.h" 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/graph.pb.h" 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h" 283ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h" 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_testutil.h" 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h" 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.pb.h" 323ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h" 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/status_test_util.h" 34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h" 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/inlined_vector.h" 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/stl_util.h" 373ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/env.h" 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h" 39564abcc02f98ae83f8ae9969a7546b510efbbb94Josh Levenberg#include "tensorflow/core/platform/macros.h" 40af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h" 41c8eaac926c929e07ac8db69f67803a2223ff2d93Josh Levenberg#include "tensorflow/core/platform/test.h" 423ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/session_options.h" 44d9cfc64a2ddf05c0b093c8fb6704c67452ee3ea0Vijay Vasudevan#include "tensorflow/core/public/version.h" 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/util/tensor_slice_reader_cache.h" 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace test { 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Return a NodeDef with the specified name/op/inputs. 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurNodeDef Node(const string& name, const string& op, 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const std::vector<string>& inputs); 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 556dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlowerinline void SetOutputAttrs(OpKernelContext::Params* params, 566dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower std::vector<AllocatorAttributes>* attrs) { 576dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attrs->clear(); 586dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower for (int index = 0; index < params->op_kernel->num_outputs(); index++) { 596dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower AllocatorAttributes attr; 606dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower const bool on_host = 616dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower (params->op_kernel->output_memory_types()[index] == HOST_MEMORY); 626dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attr.set_on_host(on_host); 636dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower attrs->push_back(attr); 646dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower } 656dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower params->output_attr_array = gtl::vector_as_array(attrs); 666dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower} 676dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace test 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Helpful functions to test operators. 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This class will eventually be replaced / heavily modified 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to use the BrainClient interface. 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass OpsTestBase : public ::testing::Test { 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OpsTestBase() : device_type_(DEVICE_CPU) { 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur device_.reset( 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK(device_.get()) << "Could not create CPU device"; 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~OpsTestBase() override { 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::STLDeleteElements(&tensors_); 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context_.reset(nullptr); 85ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.reset(nullptr); 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); } 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Clients can manipulate the underlying NodeDef via this accessor. 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur NodeDef* node_def() { return &node_def_; } 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Initializes an operator that takes in 'input_types' as input 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // and output types as output. 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the status of initialization. 979e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); } 989e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden 999e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden // Only use this directly if you have a deprecated op that you need to test. 1009e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden Status InitOpWithGraphVersion(int graph_def_version) { 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Status status; 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), 1039e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden node_def_, graph_def_version, &status); 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (kernel_ != nullptr) input_types_ = kernel_->input_types(); 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return status; 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Adds an input for every element described by the shape. 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 'input_mapping' maps an index (0...NumElements(shape)) to a 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // value. 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // TODO(vrv): Replace with something like a BrainClient Feed. 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename T> 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) { 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_GT(input_types_.size(), inputs_.size()) 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur << "Adding more inputs than types; perhaps you need to call MakeOp"; 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool is_ref = IsRefType(input_types_[inputs_.size()]); 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeToEnum<T>::v(), shape); 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur test::FillFn(input, input_mapping); 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensors_.push_back(input); 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (is_ref) { 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeToEnum<T>::v()); 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur inputs_.push_back({&lock_for_refs_, input}); 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } else { 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v()); 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur inputs_.push_back({nullptr, input}); 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Like AddInput but takes in an explicit arrayslice of data. 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename T> 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void AddInputFromArray(const TensorShape& shape, 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const gtl::ArraySlice<T>& data) { 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_GT(input_types_.size(), inputs_.size()) 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur << "Adding more inputs than types; perhaps you need to call MakeOp"; 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool is_ref = IsRefType(input_types_[inputs_.size()]); 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeToEnum<T>::v(), shape); 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur test::FillValues<T>(input, data); 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensors_.push_back(input); 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if (is_ref) { 144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), 145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeToEnum<T>::v()); 146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur inputs_.push_back({&lock_for_refs_, input}); 147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } else { 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v()); 149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur inputs_.push_back({nullptr, input}); 150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Runs an operation producing 'num_outputs' outputs. 154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the context's status after running the operation. 156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Status RunOpKernel() { 157ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower // Make sure the old OpKernelContext is deleted before the Params 158ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower // it was using. 159ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower context_.reset(nullptr); 160ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower 161ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.reset(new OpKernelContext::Params); 162ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->device = device_.get(); 163ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->frame_iter = FrameAndIter(0, 0); 164ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->inputs = &inputs_; 165ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->op_kernel = kernel_.get(); 1666dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower std::vector<AllocatorAttributes> attrs; 167ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower test::SetOutputAttrs(params_.get(), &attrs); 168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; 169ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower params_.get()->slice_reader_cache = &slice_reader_cache_wrapper; 170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 171ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower context_.reset(new OpKernelContext(params_.get())); 172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur device_->Compute(kernel_.get(), context_.get()); 173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return context_->status(); 174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the tensor input for 'input_index'. 177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES: 0 <= input_index < context_->num_inputs() 179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& GetInput(int input_index) const { 180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_LT(input_index, context_->num_inputs()); 181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK(!IsRefType(context_->input_dtype(input_index))); 182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return context_->input(input_index); 183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorValue mutable_input(int input_index) { 186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_LT(input_index, inputs_.size()); 187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return inputs_[input_index]; 188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns the tensor output for 'output_index'. 190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES: 0 <= output_index < context_->num_outputs() 192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* GetOutput(int output_index) { 193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_LT(output_index, context_->num_outputs()); 194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return context_->mutable_output(output_index); 195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Allocator* allocator() { 198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return device_->GetAllocator(AllocatorAttributes()); 199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const DataTypeVector& output_types() const { return kernel_->output_types(); } 202f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected: 204f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<Device> device_; 205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<OpKernel> kernel_; 207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur NodeDef node_def_; 208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DataTypeVector input_types_; 209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceType device_type_; 210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur mutex lock_for_refs_; // Used as the Mutex for inputs added as refs 212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::InlinedVector<TensorValue, 4> inputs_; 214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Owns Tensors. 215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<Tensor*> tensors_; 216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 217ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower std::unique_ptr<OpKernelContext::Params> params_; 218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<OpKernelContext> context_; 219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 221f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase); 222f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 223f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 224f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace tensorflow 225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ 227