1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// See docs in ../ops/nn_ops.cc.
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
#include "tensorflow/core/kernels/pad_op.h"

#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice;
413e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL
423e975ea978bac4d861bb09328b06f3c316212611Andrew Harptypedef Eigen::SyclDevice SYCLDevice;
43cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#endif  // TENSORFLOW_USE_SYCL
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
45cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tangtemplate <typename Device, typename T, typename Tpadding>
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass PadOp : public OpKernel {
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Compute(OpKernelContext* context) override {
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const Tensor& in0 = context->input(0);
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const Tensor& in1 = context->input(1);
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const int dims = in0.dims();
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    static const int kMinDims = 0;
5513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan    static const int kMaxDims = 6;
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                      kMaxDims, "]: ", dims));
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        context,
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                in1.shape().DebugString()));
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    const int fixed_dims =
6502dff6d0d838397860b6ff5256413b54da482996Josh Levenberg        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
6602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg                                                                      : dims;
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES(
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        context, fixed_dims == in1.dim_size(0),
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        errors::InvalidArgument(
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur            "The first dimension of paddings must be the rank of inputs",
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur            in1.shape().DebugString(), " ", in0.shape().DebugString()));
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
73a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan    T pad_value(0);
74a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan    if (context->num_inputs() == 3) {
75a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan      const Tensor& constant_values = context->input(2);
76a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan      OP_REQUIRES(
77a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan          context, TensorShapeUtils::IsScalar(constant_values.shape()),
78a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan          errors::InvalidArgument("constant_values must be a scalar. Found: ",
79a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                                  constant_values.shape().DebugString()));
80a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan      pad_value = context->input(2).scalar<T>()();
81a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan    }
82a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    // Compute the shape of the output tensor, and allocate it.
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    TensorShape output_shape;
85cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang    typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>();
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    for (int d = 0; d < fixed_dims; ++d) {
87cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang      const Tpadding before_d =
88cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang          paddings(d, 0);                       // Pad before existing elements.
89cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang      const Tpadding after_d = paddings(d, 1);  // Pad after existing elements.
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                  errors::InvalidArgument("Paddings must be non-negative: ",
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                          before_d, " ", after_d));
930307b9d8569011a94535322899d375a79c49df80David G. Andersen      const int64 size_d =
9402dff6d0d838397860b6ff5256413b54da482996Josh Levenberg          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      output_shape.AddDim(before_d + size_d + after_d);
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
97a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower
98a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    // If there is no padding to be done, forward the input to output.
99a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    if (output_shape.num_elements() == in0.NumElements()) {
100a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      // When num_elements == 0, shape may have changed.
101a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      Tensor out;
102a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      CHECK(out.CopyFrom(in0, output_shape));
103a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      context->set_output(0, out);
104a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower      return;
105a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower    }
106a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    Tensor* output = nullptr;
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    // Invoke the dims-specific implementation.
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    switch (fixed_dims) {
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 0:
113a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output);
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 1:
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        // TODO(irving): Once Pad doesn't need a scalar special case,
11702dff6d0d838397860b6ff5256413b54da482996Josh Levenberg        // change flat to tensor.  That is, once !allow_legacy_scalars().
118a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<1>(context, in0.flat<T>(), paddings, pad_value, output);
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 2:
121a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output);
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 3:
124a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output);
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 4:
127a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output);
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      case 5:
130a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output);
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        break;
13213b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan      case 6:
133a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan        Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output);
13413b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan        break;
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      default:
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        OP_REQUIRES(context, false,
13713b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan                    errors::InvalidArgument("Only ranks up to 6 supported: ",
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                            in0.shape().DebugString()));
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <int Dims>
144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Operate(OpKernelContext* context,
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur               typename TTypes<T, Dims>::ConstTensor input,
146cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang               typename TTypes<Tpadding>::ConstMatrix paddings, T pad_value,
147a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan               Tensor* output) {
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_EQ(Dims, paddings.dimension(0));
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    CHECK_EQ(2, paddings.dimension(1));
150cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang    Eigen::array<Eigen::IndexPair<Tpadding>, Dims> paddings_array;
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    for (int i = 0; i < Dims; ++i) {
1525c3977c297f07d1cc14591844e5df202b1994c85Benoit Steiner      paddings_array[i] = {paddings(i, 0), paddings(i, 1)};
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
154cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang    functor::Pad<Device, T, Tpadding, Dims> functor;
155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
156a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan            paddings_array, pad_value);
157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
160cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#define REGISTER_KERNEL(type)                                     \
161cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
162cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_CPU)                 \
163cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<type>("T")          \
164cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int32>("Tpaddings") \
165cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings"),            \
166cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<CPUDevice, type, int32>);         \
167cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
168cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_CPU)                 \
169cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<type>("T")          \
170cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
171cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings"),            \
172cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<CPUDevice, type, int64>);         \
173cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
174cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_CPU)                 \
175cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<type>("T")          \
176cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int32>("Tpaddings") \
177cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings")             \
178cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("constant_values"),     \
179cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<CPUDevice, type, int32>);         \
180cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
181cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_CPU)                 \
182cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<type>("T")          \
183cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
184cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings")             \
185cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("constant_values"),     \
186cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<CPUDevice, type, int64>);
187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
1887f12947e4f31cdf9a0cca291a653980fa204d686Benoit SteinerTF_CALL_POD_TYPES(REGISTER_KERNEL);
189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#undef REGISTER_KERNEL
190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#if GOOGLE_CUDA
192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Forward declarations of the functor specializations for GPU.
193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace functor {
194a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan#define DECLARE_GPU_SPEC(T, Dims)                                         \
195a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan  template <>                                                             \
196cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  void Pad<GPUDevice, T, int32, Dims>::operator()(                        \
197a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
198a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan      typename TTypes<T, Dims>::ConstTensor input,                        \
1995c3977c297f07d1cc14591844e5df202b1994c85Benoit Steiner      Eigen::array<Eigen::IndexPair<int32>, Dims> paddings, T pad_value); \
200cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  extern template struct Pad<GPUDevice, T, int32, Dims>;                  \
201cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  template <>                                                             \
202cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  void Pad<GPUDevice, T, int64, Dims>::operator()(                        \
203cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
204cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang      typename TTypes<T, Dims>::ConstTensor input,                        \
205cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang      Eigen::array<Eigen::IndexPair<int64>, Dims> paddings, T pad_value); \
206cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  extern template struct Pad<GPUDevice, T, int64, Dims>;
207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define DECLARE_GPU_SPECS(T) \
209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DECLARE_GPU_SPEC(T, 0);    \
210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DECLARE_GPU_SPEC(T, 1);    \
211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DECLARE_GPU_SPEC(T, 2);    \
212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DECLARE_GPU_SPEC(T, 3);    \
213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DECLARE_GPU_SPEC(T, 4);    \
21413b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan  DECLARE_GPU_SPEC(T, 5);    \
21513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan  DECLARE_GPU_SPEC(T, 6);
216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace functor
219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Registration of the GPU implementations.
221079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan#define REGISTER_GPU_KERNEL(T)                                    \
222079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
223079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan                              .Device(DEVICE_GPU)                 \
224079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan                              .TypeConstraint<T>("T")             \
225079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan                              .TypeConstraint<int32>("Tpaddings") \
226079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan                              .HostMemory("paddings"),            \
227cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<GPUDevice, T, int32>);            \
228cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
229cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_GPU)                 \
230cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<T>("T")             \
231cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
232cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings"),            \
233cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<GPUDevice, T, int64>);            \
234a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
235a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .Device(DEVICE_GPU)                 \
236a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .TypeConstraint<T>("T")             \
237a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .TypeConstraint<int32>("Tpaddings") \
238a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .HostMemory("paddings")             \
239a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .HostMemory("constant_values"),     \
240cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<GPUDevice, T, int32>)             \
241cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
242cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_GPU)                 \
243cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<T>("T")             \
244cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
245cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings")             \
246cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("constant_values"),     \
247cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<GPUDevice, T, int64>)
248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
251ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// A special GPU kernel for int32.
252ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// TODO(b/25387198): Also enable int32 in device memory. This kernel
253ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// registration requires all int32 inputs and outputs to be in host memory.
254ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay VasudevanREGISTER_KERNEL_BUILDER(Name("Pad")
255ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan                            .Device(DEVICE_GPU)
256ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan                            .TypeConstraint<int32>("T")
257079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan                            .TypeConstraint<int32>("Tpaddings")
258ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan                            .HostMemory("input")
259ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan                            .HostMemory("paddings")
260ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan                            .HostMemory("output"),
261cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int32>);
262cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("Pad")
263cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .Device(DEVICE_GPU)
264cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int32>("T")
265cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int64>("Tpaddings")
266cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("input")
267cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("paddings")
268cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("output"),
269cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int64>);
270a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ RyanREGISTER_KERNEL_BUILDER(Name("PadV2")
271a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .Device(DEVICE_GPU)
272a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .TypeConstraint<int32>("T")
273a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .TypeConstraint<int32>("Tpaddings")
274a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("input")
275a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("paddings")
276a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("constant_values")
277a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("output"),
278cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int32>);
279cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("PadV2")
280cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .Device(DEVICE_GPU)
281cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int32>("T")
282cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int64>("Tpaddings")
283cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("input")
284cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("paddings")
285cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("constant_values")
286cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("output"),
287cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int64>);
288fe056f0b5e52db86766761f5e6446a89c1aa3938Vijay Vasudevan#endif
289ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan
2903e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
2923e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#define REGISTER_SYCL_KERNEL(T)                                   \
2933e975ea978bac4d861bb09328b06f3c316212611Andrew Harp  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
2943e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                              .Device(DEVICE_SYCL)                \
2953e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                              .TypeConstraint<T>("T")             \
2963e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                              .TypeConstraint<int32>("Tpaddings") \
2973e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                              .HostMemory("paddings"),            \
298cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<SYCLDevice, T, int32>);           \
299cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
300cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_SYCL)                \
301cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<T>("T")             \
302cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
303cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings"),            \
304cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<SYCLDevice, T, int64>);           \
305a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
306a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .Device(DEVICE_SYCL)                \
307a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .TypeConstraint<T>("T")             \
308a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .TypeConstraint<int32>("Tpaddings") \
309a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .HostMemory("paddings")             \
310a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                              .HostMemory("constant_values"),     \
311cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<SYCLDevice, T, int32>)            \
312cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang  REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
313cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .Device(DEVICE_SYCL)                \
314cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<T>("T")             \
315cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .TypeConstraint<int64>("Tpaddings") \
316cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("paddings")             \
317cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                              .HostMemory("constant_values"),     \
318cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                          PadOp<SYCLDevice, T, int64>)
3193e975ea978bac4d861bb09328b06f3c316212611Andrew Harp
3201b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan HseuTF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
3213e975ea978bac4d861bb09328b06f3c316212611Andrew HarpREGISTER_KERNEL_BUILDER(Name("Pad")
3223e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .Device(DEVICE_SYCL)
3233e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .TypeConstraint<int32>("T")
3243e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .TypeConstraint<int32>("Tpaddings")
3253e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .HostMemory("input")
3263e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .HostMemory("paddings")
3273e975ea978bac4d861bb09328b06f3c316212611Andrew Harp                            .HostMemory("output"),
328cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int32>);
329cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("Pad")
330cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .Device(DEVICE_SYCL)
331cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int32>("T")
332cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int64>("Tpaddings")
333cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("input")
334cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("paddings")
335cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("output"),
336cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int64>);
337a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ RyanREGISTER_KERNEL_BUILDER(Name("PadV2")
338a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .Device(DEVICE_SYCL)
339a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .TypeConstraint<int32>("T")
340a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .TypeConstraint<int32>("Tpaddings")
341a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("input")
342a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("paddings")
343a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("constant_values")
344a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan                            .HostMemory("output"),
345cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int32>);
346cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("PadV2")
347cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .Device(DEVICE_SYCL)
348cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int32>("T")
349cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .TypeConstraint<int64>("Tpaddings")
350cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("input")
351cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("paddings")
352cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("constant_values")
353cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                            .HostMemory("output"),
354cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang                        PadOp<CPUDevice, int32, int64>);
3551b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu#undef REGISTER_SYCL_KERNEL
356cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#endif  // TENSORFLOW_USE_SYCL
3573e975ea978bac4d861bb09328b06f3c316212611Andrew Harp
358f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // end namespace tensorflow
359