ops_testutil.h revision 9e8541bc2a64d1e3d7b96160c445c822a06afd0a
19c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur/* Copyright 2015 Google Inc. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory>
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector>
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device.h"
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device_factory.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/allocator.h"
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/device_base.h"
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/graph.pb.h"
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h"
283ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h"
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_testutil.h"
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h"
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.pb.h"
323ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h"
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/status_test_util.h"
34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h"
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/inlined_vector.h"
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/stl_util.h"
373ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/env.h"
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h"
39564abcc02f98ae83f8ae9969a7546b510efbbb94Josh Levenberg#include "tensorflow/core/platform/macros.h"
40af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h"
41c8eaac926c929e07ac8db69f67803a2223ff2d93Josh Levenberg#include "tensorflow/core/platform/test.h"
423ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h"
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/session_options.h"
44d9cfc64a2ddf05c0b093c8fb6704c67452ee3ea0Vijay Vasudevan#include "tensorflow/core/public/version.h"
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/util/tensor_slice_reader_cache.h"
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace test {
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Return a NodeDef with the specified name/op/inputs.
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurNodeDef Node(const string& name, const string& op,
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur             const std::vector<string>& inputs);
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
556dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlowerinline void SetOutputAttrs(OpKernelContext::Params* params,
566dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower                           std::vector<AllocatorAttributes>* attrs) {
576dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  attrs->clear();
586dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
596dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    AllocatorAttributes attr;
606dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    const bool on_host =
616dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
626dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    attr.set_on_host(on_host);
636dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    attrs->push_back(attr);
646dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  }
656dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  params->output_attr_array = gtl::vector_as_array(attrs);
666dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower}
676dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace test
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Helpful functions to test operators.
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur//
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This class will eventually be replaced / heavily modified
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to use the BrainClient interface.
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass OpsTestBase : public ::testing::Test {
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  OpsTestBase() : device_type_(DEVICE_CPU) {
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    device_.reset(
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK(device_.get()) << "Could not create CPU device";
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~OpsTestBase() override {
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    gtl::STLDeleteElements(&tensors_);
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    context_.reset(nullptr);
85ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.reset(nullptr);
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Clients can manipulate the underlying NodeDef via this accessor.
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  NodeDef* node_def() { return &node_def_; }
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Initializes an operator that takes in 'input_types' as input
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // and output types as output.
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the status of initialization.
979e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
989e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden
999e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  // Only use this directly if you have a deprecated op that you need to test.
1009e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  Status InitOpWithGraphVersion(int graph_def_version) {
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Status status;
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
1039e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden                             node_def_, graph_def_version, &status);
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return status;
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Adds an input for every element described by the shape.
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // 'input_mapping' maps an index (0...NumElements(shape)) to a
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value.
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // TODO(vrv): Replace with something like a BrainClient Feed.
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename T>
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_GT(input_types_.size(), inputs_.size())
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        << "Adding more inputs than types; perhaps you need to call MakeOp";
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    bool is_ref = IsRefType(input_types_[inputs_.size()]);
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                               DataTypeToEnum<T>::v(), shape);
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    test::FillFn(input, input_mapping);
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    tensors_.push_back(input);
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    if (is_ref) {
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur               DataTypeToEnum<T>::v());
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      inputs_.push_back({&lock_for_refs_, input});
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    } else {
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      inputs_.push_back({nullptr, input});
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Like AddInput but takes in an explicit arrayslice of data.
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename T>
134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void AddInputFromArray(const TensorShape& shape,
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                         const gtl::ArraySlice<T>& data) {
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_GT(input_types_.size(), inputs_.size())
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        << "Adding more inputs than types; perhaps you need to call MakeOp";
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    bool is_ref = IsRefType(input_types_[inputs_.size()]);
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                               DataTypeToEnum<T>::v(), shape);
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    test::FillValues<T>(input, data);
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    tensors_.push_back(input);
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    if (is_ref) {
144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur               DataTypeToEnum<T>::v());
146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      inputs_.push_back({&lock_for_refs_, input});
147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    } else {
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      inputs_.push_back({nullptr, input});
150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Runs an operation producing 'num_outputs' outputs.
154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the context's status after running the operation.
156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Status RunOpKernel() {
157ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    // Make sure the old OpKernelContext is deleted before the Params
158ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    // it was using.
159ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    context_.reset(nullptr);
160ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower
161ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.reset(new OpKernelContext::Params);
162ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->device = device_.get();
163ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->frame_iter = FrameAndIter(0, 0);
164ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->inputs = &inputs_;
165ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->op_kernel = kernel_.get();
1666dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    std::vector<AllocatorAttributes> attrs;
167ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    test::SetOutputAttrs(params_.get(), &attrs);
168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
169ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
171ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    context_.reset(new OpKernelContext(params_.get()));
172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    device_->Compute(kernel_.get(), context_.get());
173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return context_->status();
174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the tensor input for 'input_index'.
177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES: 0 <= input_index < context_->num_inputs()
179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const Tensor& GetInput(int input_index) const {
180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_LT(input_index, context_->num_inputs());
181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK(!IsRefType(context_->input_dtype(input_index)));
182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return context_->input(input_index);
183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TensorValue mutable_input(int input_index) {
186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_LT(input_index, inputs_.size());
187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return inputs_[input_index];
188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the tensor output for 'output_index'.
190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES: 0 <= output_index < context_->num_outputs()
192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Tensor* GetOutput(int output_index) {
193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_LT(output_index, context_->num_outputs());
194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return context_->mutable_output(output_index);
195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Allocator* allocator() {
198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return device_->GetAllocator(AllocatorAttributes());
199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const DataTypeVector& output_types() const { return kernel_->output_types(); }
202f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected:
204f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<Device> device_;
205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<OpKernel> kernel_;
207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  NodeDef node_def_;
208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DataTypeVector input_types_;
209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DeviceType device_type_;
210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  gtl::InlinedVector<TensorValue, 4> inputs_;
214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Owns Tensors.
215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::vector<Tensor*> tensors_;
216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
217ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower  std::unique_ptr<OpKernelContext::Params> params_;
218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<OpKernelContext> context_;
219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
221f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
222f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
223f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
224f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
227