1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory>
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector>
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device.h"
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/common_runtime/device_factory.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/allocator.h"
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/device_base.h"
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/graph.pb.h"
276ada43366663210beb0159b8c1a67b26ebfe6cb7Geoffrey Irving#include "tensorflow/core/framework/node_def.pb.h"
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h"
298464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower#include "tensorflow/core/framework/resource_mgr.h"
303ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h"
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_testutil.h"
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h"
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.pb.h"
343ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h"
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/core/status_test_util.h"
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h"
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/inlined_vector.h"
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/stl_util.h"
393ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/env.h"
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h"
41564abcc02f98ae83f8ae9969a7546b510efbbb94Josh Levenberg#include "tensorflow/core/platform/macros.h"
42af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h"
43c8eaac926c929e07ac8db69f67803a2223ff2d93Josh Levenberg#include "tensorflow/core/platform/test.h"
443ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h"
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/session_options.h"
46d9cfc64a2ddf05c0b093c8fb6704c67452ee3ea0Vijay Vasudevan#include "tensorflow/core/public/version.h"
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/util/tensor_slice_reader_cache.h"
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
49cd4f5840a834e5380536b6f04a768408c6eebf3dA. Unique TensorFlowernamespace tensorflow {
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace test {
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
526dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlowerinline void SetOutputAttrs(OpKernelContext::Params* params,
536dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower                           std::vector<AllocatorAttributes>* attrs) {
546dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  attrs->clear();
556dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
566dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    AllocatorAttributes attr;
576dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    const bool on_host =
586dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
596dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    attr.set_on_host(on_host);
606dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    attrs->push_back(attr);
616dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  }
626dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower  params->output_attr_array = gtl::vector_as_array(attrs);
636dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower}
646dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace test
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Helpful functions to test operators.
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur//
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This class will eventually be replaced / heavily modified
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to use the BrainClient interface.
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass OpsTestBase : public ::testing::Test {
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
73d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  OpsTestBase()
74d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
75d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower        device_type_(DEVICE_CPU) {
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK(device_.get()) << "Could not create CPU device";
77d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    allocator_ = device_->GetAllocator(AllocatorAttributes());
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~OpsTestBase() override {
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    gtl::STLDeleteElements(&tensors_);
82d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    gtl::STLDeleteElements(&managed_outputs_);
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    context_.reset(nullptr);
84ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.reset(nullptr);
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
87f6f07b027512a53e0c80fbd357ac8ef9ae7bf870A. Unique TensorFlower  // Allow kernel unit tests to run on GPU
88d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);
89f6f07b027512a53e0c80fbd357ac8ef9ae7bf870A. Unique TensorFlower
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Clients can manipulate the underlying NodeDef via this accessor.
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  NodeDef* node_def() { return &node_def_; }
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Initializes an operator that takes in 'input_types' as input
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // and output types as output.
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the status of initialization.
999e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
1009e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden
1019e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  // Only use this directly if you have a deprecated op that you need to test.
1029e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden  Status InitOpWithGraphVersion(int graph_def_version) {
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Status status;
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
1059e8541bc2a64d1e3d7b96160c445c822a06afd0aPete Warden                             node_def_, graph_def_version, &status);
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return status;
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Adds an input for every element described by the shape.
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // 'input_mapping' maps an index (0...NumElements(shape)) to a
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value.
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // TODO(vrv): Replace with something like a BrainClient Feed.
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename T>
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
117d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Like AddInput but takes in an explicit arrayslice of data.
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename T>
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void AddInputFromArray(const TensorShape& shape,
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                         const gtl::ArraySlice<T>& data) {
124d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
12736357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner  // Convenience function to add an input and populate it with the elements from
12836357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner  // an initializer list converting the types as needed.
12936357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner  template <typename T, typename SrcType>
13036357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner  void AddInputFromList(const TensorShape& shape,
13136357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner                        std::initializer_list<SrcType> data) {
132d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
13336357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner  }
13436357e7e1127873165694a38e3a989df4e0b6ffeBenoit Steiner
1351ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower  // Adds a Resource type as input. If <container> is empty, uses the default
1361ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower  // container name.
1371ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower  template <typename T>
1381ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower  void AddResourceInput(const string& container, const string& name,
1391ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower                        T* resource) {
1401ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    CHECK_GT(input_types_.size(), inputs_.size())
1411ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower        << "Adding more inputs than types; perhaps you need to call MakeOp";
1421ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    ResourceMgr* rm = device_->resource_manager();
1431ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    EXPECT_TRUE(
1441ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower        rm->Create(container == "" ? rm->default_container() : container, name,
1451ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower                   resource)
1461ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower            .ok());
1471ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    TypeIndex type_index = MakeTypeIndex<T>();
1481ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    ResourceHandle handle;
1491ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    handle.set_device(device_->name());
1501ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    handle.set_container(container);
1511ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    handle.set_name(name);
1521ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    handle.set_hash_code(type_index.hash_code());
1531ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    handle.set_maybe_type_name(type_index.name());
154d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({}));
1551ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    input->scalar<ResourceHandle>()() = handle;
1561ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    tensors_.push_back(input);
1571ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower    inputs_.push_back({nullptr, input});
1581ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower  }
1591ee98618dd74404e26ee0202f27e05cb3dcef5c3A. Unique TensorFlower
160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Runs an operation producing 'num_outputs' outputs.
161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the context's status after running the operation.
163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Status RunOpKernel() {
164ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    // Make sure the old OpKernelContext is deleted before the Params
165ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    // it was using.
166ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    context_.reset(nullptr);
167ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower
168ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.reset(new OpKernelContext::Params);
169ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->device = device_.get();
170ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->frame_iter = FrameAndIter(0, 0);
171ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->inputs = &inputs_;
172ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->op_kernel = kernel_.get();
1738464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower    step_container_.reset(new ScopedStepContainer(0, [](const string&) {}));
1748464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower    params_->step_container = step_container_.get();
1756dbfb95100b73ad26ebebb9be9c0429dc0cece8aA. Unique TensorFlower    std::vector<AllocatorAttributes> attrs;
176ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    test::SetOutputAttrs(params_.get(), &attrs);
177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
178ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
179cb324446acbdf0d3d2129904361cf0bcbe53e852Pete Warden    params_.get()->resource_manager = device_.get()->resource_manager();
180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
181ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower    context_.reset(new OpKernelContext(params_.get()));
182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    device_->Compute(kernel_.get(), context_.get());
183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return context_->status();
184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the tensor input for 'input_index'.
187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES: 0 <= input_index < context_->num_inputs()
189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const Tensor& GetInput(int input_index) const {
190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_LT(input_index, context_->num_inputs());
191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK(!IsRefType(context_->input_dtype(input_index)));
192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return context_->input(input_index);
193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TensorValue mutable_input(int input_index) {
196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_LT(input_index, inputs_.size());
197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return inputs_[input_index];
198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns the tensor output for 'output_index'.
200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES: 0 <= output_index < context_->num_outputs()
202d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  Tensor* GetOutput(int output_index);
203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
204d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  Allocator* allocator() { return allocator_; }
205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const DataTypeVector& output_types() const { return kernel_->output_types(); }
207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
208d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower private:
209d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  Tensor* AddInput(DataType dtype, const TensorShape& shape) {
210d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    CHECK_GT(input_types_.size(), inputs_.size())
211d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower        << "Adding more inputs than types; perhaps you need to call MakeOp";
212d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    bool is_ref = IsRefType(input_types_[inputs_.size()]);
213d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    Tensor* input = new Tensor(allocator(), dtype, shape);
214d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    tensors_.push_back(input);
215d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    if (is_ref) {
216d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype);
217d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      inputs_.push_back({&lock_for_refs_, input});
218d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    } else {
219d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      CHECK_EQ(input_types_[inputs_.size()], dtype);
220d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      inputs_.push_back({nullptr, input});
221d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    }
222d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    return input;
223d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  }
224d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower
225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected:
226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<Device> device_;
227d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  // The device allocator, or the managed_allocator_ below if running on GPU.
228d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  Allocator* allocator_;
229f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
230f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<OpKernel> kernel_;
2318464d0516d3f2c857723aa4c42fb1947e6c480d8A. Unique TensorFlower  std::unique_ptr<ScopedStepContainer> step_container_;
232f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  NodeDef node_def_;
233f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DataTypeVector input_types_;
234f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DeviceType device_type_;
235f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
236f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
237f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
238f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  gtl::InlinedVector<TensorValue, 4> inputs_;
239f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Owns Tensors.
240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::vector<Tensor*> tensors_;
241d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  // Copies of the outputs in unified memory (host and device accessible).
242d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  std::vector<Tensor*> managed_outputs_;
243f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
244ff8522de343a90813fc4e5cbb249e308c1819f1dA. Unique TensorFlower  std::unique_ptr<OpKernelContext::Params> params_;
245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<OpKernelContext> context_;
246d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  // Unified memory allocator, only used when running on GPU.
247d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  std::unique_ptr<Allocator> managed_allocator_;
248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
251f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
252f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
253f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
254f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
255f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
256