1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/resource_op_kernel.h" 17 18#include <memory> 19 20#include "tensorflow/core/framework/allocator.h" 21#include "tensorflow/core/framework/node_def_builder.h" 22#include "tensorflow/core/framework/node_def_util.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/types.h" 25#include "tensorflow/core/lib/core/errors.h" 26#include "tensorflow/core/lib/core/status_test_util.h" 27#include "tensorflow/core/lib/strings/strcat.h" 28#include "tensorflow/core/platform/mutex.h" 29#include "tensorflow/core/platform/test.h" 30#include "tensorflow/core/platform/thread_annotations.h" 31#include "tensorflow/core/public/version.h" 32 33namespace tensorflow { 34namespace { 35 36// Stub DeviceBase subclass which only returns allocators. 37class StubDevice : public DeviceBase { 38 public: 39 StubDevice() : DeviceBase(nullptr) {} 40 41 Allocator* GetAllocator(AllocatorAttributes) override { 42 return cpu_allocator(); 43 } 44}; 45 46// Stub resource for testing resource op kernel. 47class StubResource : public ResourceBase { 48 public: 49 string DebugString() override { return ""; } 50 int code; 51}; 52 53class StubResourceOpKernel : public ResourceOpKernel<StubResource> { 54 public: 55 using ResourceOpKernel::ResourceOpKernel; 56 57 StubResource* resource() LOCKS_EXCLUDED(mu_) { 58 mutex_lock lock(mu_); 59 return resource_; 60 } 61 62 private: 63 Status CreateResource(StubResource** resource) override { 64 *resource = CHECK_NOTNULL(new StubResource); 65 return GetNodeAttr(def(), "code", &(*resource)->code); 66 } 67 68 Status VerifyResource(StubResource* resource) override { 69 int code; 70 TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code)); 71 if (code != resource->code) { 72 return errors::InvalidArgument("stub has code ", resource->code, 73 " but requested code ", code); 74 } 75 return Status::OK(); 76 } 77}; 78 79REGISTER_OP("StubResourceOp") 80 .Attr("code: int") 81 .Attr("container: string = ''") 82 .Attr("shared_name: string = ''") 83 .Output("output: Ref(string)"); 84 85REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU), 86 StubResourceOpKernel); 87 88class ResourceOpKernelTest : public ::testing::Test { 89 protected: 90 std::unique_ptr<StubResourceOpKernel> CreateOp(int code, 91 const string& shared_name) { 92 NodeDef node_def; 93 TF_CHECK_OK( 94 NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp") 95 .Attr("code", code) 96 .Attr("shared_name", shared_name) 97 .Finalize(&node_def)); 98 Status status; 99 std::unique_ptr<OpKernel> op(CreateOpKernel( 100 DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()), 101 node_def, TF_GRAPH_DEF_VERSION, &status)); 102 TF_EXPECT_OK(status) << status; 103 EXPECT_TRUE(op != nullptr); 104 105 // Downcast to StubResourceOpKernel to call resource() later. 106 std::unique_ptr<StubResourceOpKernel> resource_op( 107 dynamic_cast<StubResourceOpKernel*>(op.get())); 108 EXPECT_TRUE(resource_op != nullptr); 109 if (resource_op != nullptr) { 110 op.release(); 111 } 112 return resource_op; 113 } 114 115 Status RunOpKernel(OpKernel* op) { 116 OpKernelContext::Params params; 117 118 params.device = &device_; 119 params.resource_manager = &mgr_; 120 params.op_kernel = op; 121 122 OpKernelContext context(¶ms); 123 op->Compute(&context); 124 return context.status(); 125 } 126 127 StubDevice device_; 128 ResourceMgr mgr_; 129 int count_ = 0; 130}; 131 132TEST_F(ResourceOpKernelTest, PrivateResource) { 133 // Empty shared_name means private resource. 134 const int code = -100; 135 auto op = CreateOp(code, ""); 136 ASSERT_TRUE(op != nullptr); 137 TF_EXPECT_OK(RunOpKernel(op.get())); 138 139 // Default non-shared name provided from ContainerInfo. 140 const string key = "_0_" + op->name(); 141 142 StubResource* resource; 143 TF_ASSERT_OK( 144 mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource)); 145 EXPECT_EQ(op->resource(), resource); // Check resource identity. 146 EXPECT_EQ(code, resource->code); // Check resource stored information. 147 resource->Unref(); 148 149 // Destroy the op kernel. Expect the resource to be released. 150 op = nullptr; 151 Status s = 152 mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource); 153 154 EXPECT_FALSE(s.ok()); 155} 156 157TEST_F(ResourceOpKernelTest, SharedResource) { 158 const string shared_name = "shared_stub"; 159 const int code = -201; 160 auto op = CreateOp(code, shared_name); 161 ASSERT_TRUE(op != nullptr); 162 TF_EXPECT_OK(RunOpKernel(op.get())); 163 164 StubResource* resource; 165 TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, 166 &resource)); 167 EXPECT_EQ(op->resource(), resource); // Check resource identity. 168 EXPECT_EQ(code, resource->code); // Check resource stored information. 169 resource->Unref(); 170 171 // Destroy the op kernel. Expect the resource not to be released. 172 op = nullptr; 173 TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, 174 &resource)); 175 resource->Unref(); 176} 177 178TEST_F(ResourceOpKernelTest, LookupShared) { 179 auto op1 = CreateOp(-333, "shared_stub"); 180 auto op2 = CreateOp(-333, "shared_stub"); 181 ASSERT_TRUE(op1 != nullptr); 182 ASSERT_TRUE(op2 != nullptr); 183 184 TF_EXPECT_OK(RunOpKernel(op1.get())); 185 TF_EXPECT_OK(RunOpKernel(op2.get())); 186 EXPECT_EQ(op1->resource(), op2->resource()); 187} 188 189TEST_F(ResourceOpKernelTest, VerifyResource) { 190 auto op1 = CreateOp(-444, "shared_stub"); 191 auto op2 = CreateOp(0, "shared_stub"); // Different resource code. 192 ASSERT_TRUE(op1 != nullptr); 193 ASSERT_TRUE(op2 != nullptr); 194 195 TF_EXPECT_OK(RunOpKernel(op1.get())); 196 EXPECT_FALSE(RunOpKernel(op2.get()).ok()); 197 EXPECT_TRUE(op1->resource() != nullptr); 198 EXPECT_TRUE(op2->resource() == nullptr); 199} 200 201} // namespace 202} // namespace tensorflow 203