1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/resource_op_kernel.h"
17
18#include <memory>
19
20#include "tensorflow/core/framework/allocator.h"
21#include "tensorflow/core/framework/node_def_builder.h"
22#include "tensorflow/core/framework/node_def_util.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/types.h"
25#include "tensorflow/core/lib/core/errors.h"
26#include "tensorflow/core/lib/core/status_test_util.h"
27#include "tensorflow/core/lib/strings/strcat.h"
28#include "tensorflow/core/platform/mutex.h"
29#include "tensorflow/core/platform/test.h"
30#include "tensorflow/core/platform/thread_annotations.h"
31#include "tensorflow/core/public/version.h"
32
33namespace tensorflow {
34namespace {
35
36// Stub DeviceBase subclass which only returns allocators.
37class StubDevice : public DeviceBase {
38 public:
39  StubDevice() : DeviceBase(nullptr) {}
40
41  Allocator* GetAllocator(AllocatorAttributes) override {
42    return cpu_allocator();
43  }
44};
45
46// Stub resource for testing resource op kernel.
47class StubResource : public ResourceBase {
48 public:
49  string DebugString() override { return ""; }
50  int code;
51};
52
53class StubResourceOpKernel : public ResourceOpKernel<StubResource> {
54 public:
55  using ResourceOpKernel::ResourceOpKernel;
56
57  StubResource* resource() LOCKS_EXCLUDED(mu_) {
58    mutex_lock lock(mu_);
59    return resource_;
60  }
61
62 private:
63  Status CreateResource(StubResource** resource) override {
64    *resource = CHECK_NOTNULL(new StubResource);
65    return GetNodeAttr(def(), "code", &(*resource)->code);
66  }
67
68  Status VerifyResource(StubResource* resource) override {
69    int code;
70    TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code));
71    if (code != resource->code) {
72      return errors::InvalidArgument("stub has code ", resource->code,
73                                     " but requested code ", code);
74    }
75    return Status::OK();
76  }
77};
78
// Op used only by this test. "code" is a required int attr: StubResourceOpKernel
// stores it in the resource it creates and checks it when reusing one.
// "container" and "shared_name" default to the empty string; the tests pass an
// empty shared_name to request a kernel-private resource.
// NOTE(review): the Ref(string) output is presumably the handle emitted by
// ResourceOpKernel; confirm against resource_op_kernel.h.
REGISTER_OP("StubResourceOp")
    .Attr("code: int")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("output: Ref(string)");

// Bind StubResourceOp to the stub kernel on CPU.
REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU),
                        StubResourceOpKernel);
87
88class ResourceOpKernelTest : public ::testing::Test {
89 protected:
90  std::unique_ptr<StubResourceOpKernel> CreateOp(int code,
91                                                 const string& shared_name) {
92    NodeDef node_def;
93    TF_CHECK_OK(
94        NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp")
95            .Attr("code", code)
96            .Attr("shared_name", shared_name)
97            .Finalize(&node_def));
98    Status status;
99    std::unique_ptr<OpKernel> op(CreateOpKernel(
100        DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()),
101        node_def, TF_GRAPH_DEF_VERSION, &status));
102    TF_EXPECT_OK(status) << status;
103    EXPECT_TRUE(op != nullptr);
104
105    // Downcast to StubResourceOpKernel to call resource() later.
106    std::unique_ptr<StubResourceOpKernel> resource_op(
107        dynamic_cast<StubResourceOpKernel*>(op.get()));
108    EXPECT_TRUE(resource_op != nullptr);
109    if (resource_op != nullptr) {
110      op.release();
111    }
112    return resource_op;
113  }
114
115  Status RunOpKernel(OpKernel* op) {
116    OpKernelContext::Params params;
117
118    params.device = &device_;
119    params.resource_manager = &mgr_;
120    params.op_kernel = op;
121
122    OpKernelContext context(&params);
123    op->Compute(&context);
124    return context.status();
125  }
126
127  StubDevice device_;
128  ResourceMgr mgr_;
129  int count_ = 0;
130};
131
132TEST_F(ResourceOpKernelTest, PrivateResource) {
133  // Empty shared_name means private resource.
134  const int code = -100;
135  auto op = CreateOp(code, "");
136  ASSERT_TRUE(op != nullptr);
137  TF_EXPECT_OK(RunOpKernel(op.get()));
138
139  // Default non-shared name provided from ContainerInfo.
140  const string key = "_0_" + op->name();
141
142  StubResource* resource;
143  TF_ASSERT_OK(
144      mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource));
145  EXPECT_EQ(op->resource(), resource);  // Check resource identity.
146  EXPECT_EQ(code, resource->code);      // Check resource stored information.
147  resource->Unref();
148
149  // Destroy the op kernel. Expect the resource to be released.
150  op = nullptr;
151  Status s =
152      mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource);
153
154  EXPECT_FALSE(s.ok());
155}
156
157TEST_F(ResourceOpKernelTest, SharedResource) {
158  const string shared_name = "shared_stub";
159  const int code = -201;
160  auto op = CreateOp(code, shared_name);
161  ASSERT_TRUE(op != nullptr);
162  TF_EXPECT_OK(RunOpKernel(op.get()));
163
164  StubResource* resource;
165  TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
166                                         &resource));
167  EXPECT_EQ(op->resource(), resource);  // Check resource identity.
168  EXPECT_EQ(code, resource->code);      // Check resource stored information.
169  resource->Unref();
170
171  // Destroy the op kernel. Expect the resource not to be released.
172  op = nullptr;
173  TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
174                                         &resource));
175  resource->Unref();
176}
177
178TEST_F(ResourceOpKernelTest, LookupShared) {
179  auto op1 = CreateOp(-333, "shared_stub");
180  auto op2 = CreateOp(-333, "shared_stub");
181  ASSERT_TRUE(op1 != nullptr);
182  ASSERT_TRUE(op2 != nullptr);
183
184  TF_EXPECT_OK(RunOpKernel(op1.get()));
185  TF_EXPECT_OK(RunOpKernel(op2.get()));
186  EXPECT_EQ(op1->resource(), op2->resource());
187}
188
189TEST_F(ResourceOpKernelTest, VerifyResource) {
190  auto op1 = CreateOp(-444, "shared_stub");
191  auto op2 = CreateOp(0, "shared_stub");  // Different resource code.
192  ASSERT_TRUE(op1 != nullptr);
193  ASSERT_TRUE(op2 != nullptr);
194
195  TF_EXPECT_OK(RunOpKernel(op1.get()));
196  EXPECT_FALSE(RunOpKernel(op2.get()).ok());
197  EXPECT_TRUE(op1->resource() != nullptr);
198  EXPECT_TRUE(op2->resource() == nullptr);
199}
200
201}  // namespace
202}  // namespace tensorflow
203