1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// See docs in ../ops/data_flow_ops.cc. 17 18#include <limits.h> 19#include <vector> 20 21#include "tensorflow/core/common_runtime/device.h" 22#include "tensorflow/core/framework/device_base.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/register_types.h" 25#include "tensorflow/core/framework/tensor.h" 26#include "tensorflow/core/framework/tensor_shape.h" 27#include "tensorflow/core/framework/types.h" 28#include "tensorflow/core/lib/core/errors.h" 29#include "tensorflow/core/lib/gtl/map_util.h" 30#include "tensorflow/core/platform/logging.h" 31#include "tensorflow/core/platform/macros.h" 32#include "tensorflow/core/platform/mutex.h" 33#include "tensorflow/core/platform/thread_annotations.h" 34#include "tensorflow/core/platform/types.h" 35 36namespace tensorflow { 37 38class GetSessionHandleOp : public OpKernel { 39 public: 40 explicit GetSessionHandleOp(OpKernelConstruction* context) 41 : OpKernel(context) {} 42 43 void Compute(OpKernelContext* ctx) override { 44 const Tensor& val = ctx->input(0); 45 int64 id = ctx->session_state()->GetNewId(); 46 TensorStore::TensorAndKey tk{val, id, requested_device()}; 47 OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk)); 48 49 Tensor* handle = nullptr; 50 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); 51 if (ctx->expected_output_dtype(0) == DT_RESOURCE) { 52 ResourceHandle resource_handle = MakeResourceHandle<Tensor>( 53 ctx, SessionState::kTensorHandleResourceTypeName, 54 tk.GetHandle(name())); 55 resource_handle.set_maybe_type_name( 56 SessionState::kTensorHandleResourceTypeName); 57 handle->scalar<ResourceHandle>()() = resource_handle; 58 } else { 59 // Legacy behavior in V1. 60 handle->flat<string>().setConstant(tk.GetHandle(name())); 61 } 62 } 63 64 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp); 65}; 66 67REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU), 68 GetSessionHandleOp); 69REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU), 70 GetSessionHandleOp); 71 72#define REGISTER_GPU_KERNEL(type) \ 73 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \ 74 .Device(DEVICE_GPU) \ 75 .HostMemory("handle") \ 76 .TypeConstraint<type>("T"), \ 77 GetSessionHandleOp) \ 78 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \ 79 .Device(DEVICE_GPU) \ 80 .HostMemory("handle") \ 81 .TypeConstraint<type>("T"), \ 82 GetSessionHandleOp) 83 84TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); 85REGISTER_GPU_KERNEL(bool); 86#undef REGISTER_GPU_KERNEL 87 88#ifdef TENSORFLOW_USE_SYCL 89#define REGISTER_SYCL_KERNEL(type) \ 90 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \ 91 .Device(DEVICE_SYCL) \ 92 .HostMemory("handle") \ 93 .TypeConstraint<type>("T"), \ 94 GetSessionHandleOp) \ 95 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \ 96 .Device(DEVICE_SYCL) \ 97 .HostMemory("handle") \ 98 .TypeConstraint<type>("T"), \ 99 GetSessionHandleOp) 100 101TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); 102REGISTER_SYCL_KERNEL(bool); 103#undef REGISTER_SYCL_KERNEL 104#endif // TENSORFLOW_USE_SYCL 105 106class GetSessionTensorOp : public OpKernel { 107 public: 108 explicit GetSessionTensorOp(OpKernelConstruction* context) 109 : OpKernel(context) {} 110 111 void Compute(OpKernelContext* ctx) override { 112 const Tensor& handle = ctx->input(0); 113 const string& name = handle.scalar<string>()(); 114 Tensor val; 115 OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val)); 116 ctx->set_output(0, val); 117 } 118 119 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp); 120}; 121 122REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU), 123 GetSessionTensorOp); 124 125#define REGISTER_GPU_KERNEL(type) \ 126 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \ 127 .Device(DEVICE_GPU) \ 128 .HostMemory("handle") \ 129 .TypeConstraint<type>("dtype"), \ 130 GetSessionTensorOp) 131 132TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); 133REGISTER_GPU_KERNEL(bool); 134#undef REGISTER_GPU_KERNEL 135 136#ifdef TENSORFLOW_USE_SYCL 137#define REGISTER_SYCL_KERNEL(type) \ 138 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \ 139 .Device(DEVICE_SYCL) \ 140 .HostMemory("handle") \ 141 .TypeConstraint<type>("dtype"), \ 142 GetSessionTensorOp) 143 144TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); 145REGISTER_SYCL_KERNEL(bool); 146#undef REGISTER_SYCL_KERNEL 147#endif // TENSORFLOW_USE_SYCL 148 149class DeleteSessionTensorOp : public OpKernel { 150 public: 151 explicit DeleteSessionTensorOp(OpKernelConstruction* context) 152 : OpKernel(context) {} 153 154 void Compute(OpKernelContext* ctx) override { 155 const Tensor& handle = ctx->input(0); 156 const string& name = handle.scalar<string>()(); 157 OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name)); 158 } 159 160 TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp); 161}; 162 163REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU), 164 DeleteSessionTensorOp); 165REGISTER_KERNEL_BUILDER( 166 Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"), 167 DeleteSessionTensorOp); 168 169#ifdef TENSORFLOW_USE_SYCL 170REGISTER_KERNEL_BUILDER( 171 Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"), 172 DeleteSessionTensorOp); 173#endif // TENSORFLOW_USE_SYCL 174} // namespace tensorflow 175