pad_op.cc revision 3e975ea978bac4d861bb09328b06f3c316212611
1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// See docs in ../ops/nn_ops.cc.
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/kernels/pad_op.h"
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory>
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <string>
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <utility>
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
2656313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op.h"
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h"
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/register_types.h"
303ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h"
313ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor_shape.h"
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_types.h"
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h"
3456313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/platform/logging.h"
353ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h"
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice;
413e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL
423e975ea978bac4d861bb09328b06f3c316212611Andrew Harptypedef Eigen::SyclDevice SYCLDevice;
433e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#endif // TENSORFLOW_USE_SYCL
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename T>
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass PadOp : public OpKernel {
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Compute(OpKernelContext* context) override {
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const Tensor& in0 = context->input(0);
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const Tensor& in1 = context->input(1);
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const int dims = in0.dims();
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    static const int kMinDims = 0;
5513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan    static const int kMaxDims = 6;
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                      kMaxDims, "]: ", dims));
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        context,
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                in1.shape().DebugString()));
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const int fixed_dims =
6502dff6d0d838397860b6ff5256413b54da482996Josh Levenberg        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
6602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg                                                                      : dims;
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        context, fixed_dims == in1.dim_size(0),
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        errors::InvalidArgument(
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur            "The first dimension of paddings must be the rank of inputs",
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur            in1.shape().DebugString(), " ", in0.shape().DebugString()));
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    // Compute the shape of the output tensor, and allocate it.
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    TensorShape output_shape;
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    for (int d = 0; d < fixed_dims; ++d) {
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      const int32 before_d = paddings(d, 0);  // Pad before existing elements.
7859f1eba5fb94506a205fa2e81145667754739da5Martin Wicke      const int32 after_d = paddings(d, 1);   // Pad after existing elements.
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                  errors::InvalidArgument("Paddings must be non-negative: ",
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                          before_d, " ", after_d));
820307b9d8569011a94535322899d375a79c49df80David G. Andersen      const int64 size_d =
8302dff6d0d838397860b6ff5256413b54da482996Josh Levenberg          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      output_shape.AddDim(before_d + size_d + after_d);
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
86a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower
87a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    // If there is no padding to be done, forward the input to output.
88a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    if (output_shape.num_elements() == in0.NumElements()) {
89a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      // When num_elements == 0, shape may have changed.
90a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      Tensor out;
91a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      CHECK(out.CopyFrom(in0, output_shape));
92a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      context->set_output(0, out);
93a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      return;
94a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    }
95a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Tensor* output = nullptr;
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    // Invoke the dims-specific implementation.
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    switch (fixed_dims) {
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 0:
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<0>(context, in0.tensor<T, 0>(), paddings, output);
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 1:
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        // TODO(irving): Once Pad doesn't need a scalar special case,
10602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg        // change flat to tensor.  That is, once !allow_legacy_scalars().
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<1>(context, in0.flat<T>(), paddings, output);
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 2:
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<2>(context, in0.tensor<T, 2>(), paddings, output);
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 3:
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<3>(context, in0.tensor<T, 3>(), paddings, output);
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 4:
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<4>(context, in0.tensor<T, 4>(), paddings, output);
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 5:
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        Operate<5>(context, in0.tensor<T, 5>(), paddings, output);
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
12113b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan      case 6:
12213b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan        Operate<6>(context, in0.tensor<T, 6>(), paddings, output);
12313b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan        break;
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      default:
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        OP_REQUIRES(context, false,
12613b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan                    errors::InvalidArgument("Only ranks up to 6 supported: ",
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                            in0.shape().DebugString()));
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <int Dims>
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Operate(OpKernelContext* context,
134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur               typename TTypes<T, Dims>::ConstTensor input,
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur               TTypes<int32>::ConstMatrix paddings, Tensor* output) {
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_EQ(Dims, paddings.dimension(0));
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_EQ(2, paddings.dimension(1));
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Eigen::array<std::pair<int32, int32>, Dims> paddings_array;
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    for (int i = 0; i < Dims; ++i) {
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1));
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    functor::Pad<Device, T, Dims> functor;
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur            paddings_array);
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define REGISTER_KERNEL(type)                            \
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  REGISTER_KERNEL_BUILDER(Name("Pad")                    \
150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                              .Device(DEVICE_CPU)        \
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                              .TypeConstraint<type>("T") \
152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                              .HostMemory("paddings"),   \
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                          PadOp<CPUDevice, type>)
154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
1557f12947e4f31cdf9a0cca291a653980fa204d686Benoit SteinerTF_CALL_POD_TYPES(REGISTER_KERNEL);
156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#undef REGISTER_KERNEL
157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
// Declares the Pad<GPUDevice, T, Dims>::operator() specialization and marks
// the struct as `extern template`, so it is not instantiated in this file.
// NOTE(review): the definitions presumably live in the CUDA source for this
// op -- confirm.
#define DECLARE_GPU_SPEC(T, Dims)                                  \
  template <>                                                      \
  void Pad<GPUDevice, T, Dims>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \
      typename TTypes<T, Dims>::ConstTensor input,                 \
      Eigen::array<std::pair<int32, int32>, Dims> paddings);       \
  extern template struct Pad<GPUDevice, T, Dims>;

// Declares the specializations for every rank PadOp dispatches on (0..6).
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 0);    \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
// Helper macros are local to this section; do not leak them further.
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          PadOp<GPUDevice, T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Because input, paddings and output are all pinned to host memory, the
// CPUDevice implementation is registered under DEVICE_GPU.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
#endif  // GOOGLE_CUDA
204ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan
#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SYCL_KERNEL(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          PadOp<SYCLDevice, T>)

REGISTER_SYCL_KERNEL(float);
REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL

// A special SYCL kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Because input, paddings and output are all pinned to host memory, the
// CPUDevice implementation is registered under DEVICE_SYCL.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
#endif  // TENSORFLOW_USE_SYCL
2303e975ea978bac4d861bb09328b06f3c316212611Andrew Harp
231f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // end namespace tensorflow
232