1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/data_flow_ops.cc.
17
18#include <limits.h>
19#include <vector>
20
21#include "tensorflow/core/common_runtime/device.h"
22#include "tensorflow/core/framework/device_base.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/lib/core/errors.h"
29#include "tensorflow/core/lib/gtl/map_util.h"
30#include "tensorflow/core/platform/logging.h"
31#include "tensorflow/core/platform/macros.h"
32#include "tensorflow/core/platform/mutex.h"
33#include "tensorflow/core/platform/thread_annotations.h"
34#include "tensorflow/core/platform/types.h"
35
36namespace tensorflow {
37
38class GetSessionHandleOp : public OpKernel {
39 public:
40  explicit GetSessionHandleOp(OpKernelConstruction* context)
41      : OpKernel(context) {}
42
43  void Compute(OpKernelContext* ctx) override {
44    const Tensor& val = ctx->input(0);
45    int64 id = ctx->session_state()->GetNewId();
46    TensorStore::TensorAndKey tk{val, id, requested_device()};
47    OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
48
49    Tensor* handle = nullptr;
50    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
51    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
52      ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
53          ctx, SessionState::kTensorHandleResourceTypeName,
54          tk.GetHandle(name()));
55      resource_handle.set_maybe_type_name(
56          SessionState::kTensorHandleResourceTypeName);
57      handle->scalar<ResourceHandle>()() = resource_handle;
58    } else {
59      // Legacy behavior in V1.
60      handle->flat<string>().setConstant(tk.GetHandle(name()));
61    }
62  }
63
64  TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
65};
66
// CPU kernels for GetSessionHandle (V1) and GetSessionHandleV2. No
// TypeConstraint is given, so a single registration matches any input dtype.
REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
                        GetSessionHandleOp);
REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
                        GetSessionHandleOp);

// GPU kernels, registered per input dtype `T` (all number types plus bool).
// The "handle" output is pinned to host memory because Compute() writes it
// directly from host code (a string or ResourceHandle scalar).
// NOTE: comments cannot be placed inside the line-continued macro below.
#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("handle")       \
                              .TypeConstraint<type>("T"), \
                          GetSessionHandleOp)             \
  REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("handle")       \
                              .TypeConstraint<type>("T"), \
                          GetSessionHandleOp)

TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL

#ifdef TENSORFLOW_USE_SYCL
// SYCL kernels mirror the GPU registrations: per-dtype, handle in host memory.
#define REGISTER_SYCL_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("handle")       \
                              .TypeConstraint<type>("T"), \
                          GetSessionHandleOp)             \
  REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("handle")       \
                              .TypeConstraint<type>("T"), \
                          GetSessionHandleOp)

TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
REGISTER_SYCL_KERNEL(bool);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL
105
106class GetSessionTensorOp : public OpKernel {
107 public:
108  explicit GetSessionTensorOp(OpKernelConstruction* context)
109      : OpKernel(context) {}
110
111  void Compute(OpKernelContext* ctx) override {
112    const Tensor& handle = ctx->input(0);
113    const string& name = handle.scalar<string>()();
114    Tensor val;
115    OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val));
116    ctx->set_output(0, val);
117  }
118
119  TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp);
120};
121
// CPU kernel for GetSessionTensor; no TypeConstraint, so it matches any
// `dtype` attr value.
REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU),
                        GetSessionTensorOp);

// GPU kernels, registered per output `dtype` (all number types plus bool).
// The string "handle" input is pinned to host memory because Compute() reads
// it directly from host code.
// NOTE: comments cannot be placed inside the line-continued macro below.
#define REGISTER_GPU_KERNEL(type)                             \
  REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
                              .Device(DEVICE_GPU)             \
                              .HostMemory("handle")           \
                              .TypeConstraint<type>("dtype"), \
                          GetSessionTensorOp)

TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL

#ifdef TENSORFLOW_USE_SYCL
// SYCL kernels mirror the GPU registrations: per-dtype, handle in host memory.
#define REGISTER_SYCL_KERNEL(type)                            \
  REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
                              .Device(DEVICE_SYCL)            \
                              .HostMemory("handle")           \
                              .TypeConstraint<type>("dtype"), \
                          GetSessionTensorOp)

TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
REGISTER_SYCL_KERNEL(bool);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL
148
149class DeleteSessionTensorOp : public OpKernel {
150 public:
151  explicit DeleteSessionTensorOp(OpKernelConstruction* context)
152      : OpKernel(context) {}
153
154  void Compute(OpKernelContext* ctx) override {
155    const Tensor& handle = ctx->input(0);
156    const string& name = handle.scalar<string>()();
157    OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name));
158  }
159
160  TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp);
161};
162
// DeleteSessionTensor has no type attr, so each device needs one registration.
// On non-CPU devices the string "handle" input is pinned to host memory
// because Compute() reads it directly from host code.
REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU),
                        DeleteSessionTensorOp);
REGISTER_KERNEL_BUILDER(
    Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
    DeleteSessionTensorOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
    DeleteSessionTensorOp);
#endif  // TENSORFLOW_USE_SYCL
174}  // namespace tensorflow
175