1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// See docs in ../ops/nn_ops.cc. 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/kernels/pad_op.h" 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory> 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <string> 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <utility> 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 2656313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op.h" 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h" 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/register_types.h" 303ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h" 313ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor_shape.h" 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_types.h" 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h" 3456313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/platform/logging.h" 353ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice; 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice; 413e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL 423e975ea978bac4d861bb09328b06f3c316212611Andrew Harptypedef Eigen::SyclDevice SYCLDevice; 43cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#endif // TENSORFLOW_USE_SYCL 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 45cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tangtemplate <typename Device, typename T, typename Tpadding> 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass PadOp : public OpKernel { 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {} 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Compute(OpKernelContext* context) override { 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in0 = context->input(0); 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in1 = context->input(1); 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int dims = in0.dims(); 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMinDims = 0; 5513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan static const int kMaxDims = 6; 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::Unimplemented("inputs rank not in [", kMinDims, ",", 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur kMaxDims, "]: ", dims)); 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2, 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("paddings must be a matrix with 2 columns: ", 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString())); 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int fixed_dims = 6502dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1 6602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg : dims; 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, fixed_dims == in1.dim_size(0), 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument( 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur "The first dimension of paddings must be the rank of inputs", 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString(), " ", in0.shape().DebugString())); 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 73a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan T pad_value(0); 74a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan if (context->num_inputs() == 3) { 75a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan const Tensor& constant_values = context->input(2); 76a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan OP_REQUIRES( 77a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan context, TensorShapeUtils::IsScalar(constant_values.shape()), 78a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan errors::InvalidArgument("constant_values must be a scalar. Found: ", 79a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan constant_values.shape().DebugString())); 80a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan pad_value = context->input(2).scalar<T>()(); 81a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan } 82a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Compute the shape of the output tensor, and allocate it. 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShape output_shape; 85cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>(); 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int d = 0; d < fixed_dims; ++d) { 87cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang const Tpadding before_d = 88cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang paddings(d, 0); // Pad before existing elements. 89cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang const Tpadding after_d = paddings(d, 1); // Pad after existing elements. 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, before_d >= 0 && after_d >= 0, 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("Paddings must be non-negative: ", 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur before_d, " ", after_d)); 930307b9d8569011a94535322899d375a79c49df80David G. Andersen const int64 size_d = 9402dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d); 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur output_shape.AddDim(before_d + size_d + after_d); 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 97a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower 98a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower // If there is no padding to be done, forward the input to output. 99a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower if (output_shape.num_elements() == in0.NumElements()) { 100a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower // When num_elements == 0, shape may have changed. 101a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower Tensor out; 102a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower CHECK(out.CopyFrom(in0, output_shape)); 103a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower context->set_output(0, out); 104a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower return; 105a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower } 106a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* output = nullptr; 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Invoke the dims-specific implementation. 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur switch (fixed_dims) { 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 0: 113a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output); 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 1: 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // TODO(irving): Once Pad doesn't need a scalar special case, 11702dff6d0d838397860b6ff5256413b54da482996Josh Levenberg // change flat to tensor. That is, once !allow_legacy_scalars(). 118a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<1>(context, in0.flat<T>(), paddings, pad_value, output); 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 2: 121a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output); 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 3: 124a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output); 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 4: 127a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output); 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 5: 130a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output); 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 13213b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan case 6: 133a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output); 13413b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan break; 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur default: 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, false, 13713b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan errors::InvalidArgument("Only ranks up to 6 supported: ", 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in0.shape().DebugString())); 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <int Dims> 144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Operate(OpKernelContext* context, 145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<T, Dims>::ConstTensor input, 146cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang typename TTypes<Tpadding>::ConstMatrix paddings, T pad_value, 147a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan Tensor* output) { 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(Dims, paddings.dimension(0)); 149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(2, paddings.dimension(1)); 150cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang Eigen::array<Eigen::IndexPair<Tpadding>, Dims> paddings_array; 151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int i = 0; i < Dims; ++i) { 1525c3977c297f07d1cc14591844e5df202b1994c85Benoit Steiner paddings_array[i] = {paddings(i, 0), paddings(i, 1)}; 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 154cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang functor::Pad<Device, T, Tpadding, Dims> functor; 155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input, 156a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan paddings_array, pad_value); 157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 160cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#define REGISTER_KERNEL(type) \ 161cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("Pad") \ 162cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_CPU) \ 163cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<type>("T") \ 164cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("Tpaddings") \ 165cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings"), \ 166cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, type, int32>); \ 167cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("Pad") \ 168cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_CPU) \ 169cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<type>("T") \ 170cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 171cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings"), \ 172cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, type, int64>); \ 173cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("PadV2") \ 174cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_CPU) \ 175cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<type>("T") \ 176cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("Tpaddings") \ 177cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") \ 178cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values"), \ 179cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, type, int32>); \ 180cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("PadV2") \ 181cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_CPU) \ 182cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<type>("T") \ 183cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 184cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") \ 185cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values"), \ 186cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, type, int64>); 187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1887f12947e4f31cdf9a0cca291a653980fa204d686Benoit SteinerTF_CALL_POD_TYPES(REGISTER_KERNEL); 189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#undef REGISTER_KERNEL 190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#if GOOGLE_CUDA 192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Forward declarations of the functor specializations for GPU. 193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace functor { 194a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan#define DECLARE_GPU_SPEC(T, Dims) \ 195a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan template <> \ 196cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang void Pad<GPUDevice, T, int32, Dims>::operator()( \ 197a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \ 198a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan typename TTypes<T, Dims>::ConstTensor input, \ 1995c3977c297f07d1cc14591844e5df202b1994c85Benoit Steiner Eigen::array<Eigen::IndexPair<int32>, Dims> paddings, T pad_value); \ 200cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang extern template struct Pad<GPUDevice, T, int32, Dims>; \ 201cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang template <> \ 202cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang void Pad<GPUDevice, T, int64, Dims>::operator()( \ 203cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \ 204cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang typename TTypes<T, Dims>::ConstTensor input, \ 205cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang Eigen::array<Eigen::IndexPair<int64>, Dims> paddings, T pad_value); \ 206cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang extern template struct Pad<GPUDevice, T, int64, Dims>; 207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define DECLARE_GPU_SPECS(T) \ 209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 0); \ 210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 1); \ 211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 2); \ 212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 3); \ 213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 4); \ 21413b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan DECLARE_GPU_SPEC(T, 5); \ 21513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan DECLARE_GPU_SPEC(T, 6); 216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); 218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace functor 219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Registration of the GPU implementations. 221079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan#define REGISTER_GPU_KERNEL(T) \ 222079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan REGISTER_KERNEL_BUILDER(Name("Pad") \ 223079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan .Device(DEVICE_GPU) \ 224079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan .TypeConstraint<T>("T") \ 225079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan .TypeConstraint<int32>("Tpaddings") \ 226079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan .HostMemory("paddings"), \ 227cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<GPUDevice, T, int32>); \ 228cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("Pad") \ 229cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_GPU) \ 230cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<T>("T") \ 231cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 232cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings"), \ 233cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<GPUDevice, T, int64>); \ 234a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan REGISTER_KERNEL_BUILDER(Name("PadV2") \ 235a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .Device(DEVICE_GPU) \ 236a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<T>("T") \ 237a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("Tpaddings") \ 238a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("paddings") \ 239a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("constant_values"), \ 240cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<GPUDevice, T, int32>) \ 241cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("PadV2") \ 242cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_GPU) \ 243cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<T>("T") \ 244cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 245cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") \ 246cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values"), \ 247cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<GPUDevice, T, int64>) 248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); 250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 251ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// A special GPU kernel for int32. 252ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// TODO(b/25387198): Also enable int32 in device memory. This kernel 253ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan// registration requires all int32 inputs and outputs to be in host memory. 254ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay VasudevanREGISTER_KERNEL_BUILDER(Name("Pad") 255ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan .Device(DEVICE_GPU) 256ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan .TypeConstraint<int32>("T") 257079990d8b6bb4c60f23c4d0cc9ee29190ff13b9aVijay Vasudevan .TypeConstraint<int32>("Tpaddings") 258ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan .HostMemory("input") 259ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan .HostMemory("paddings") 260ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan .HostMemory("output"), 261cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int32>); 262cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("Pad") 263cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_GPU) 264cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("T") 265cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") 266cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("input") 267cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") 268cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("output"), 269cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int64>); 270a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ RyanREGISTER_KERNEL_BUILDER(Name("PadV2") 271a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .Device(DEVICE_GPU) 272a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("T") 273a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("Tpaddings") 274a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("input") 275a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("paddings") 276a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("constant_values") 277a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("output"), 278cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int32>); 279cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("PadV2") 280cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_GPU) 281cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("T") 282cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") 283cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("input") 284cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") 285cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values") 286cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("output"), 287cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int64>); 288fe056f0b5e52db86766761f5e6446a89c1aa3938Vijay Vasudevan#endif 289ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan 2903e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL 2913e975ea978bac4d861bb09328b06f3c316212611Andrew Harp// Registration of the GPU implementations. 2923e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#define REGISTER_SYCL_KERNEL(T) \ 2933e975ea978bac4d861bb09328b06f3c316212611Andrew Harp REGISTER_KERNEL_BUILDER(Name("Pad") \ 2943e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .Device(DEVICE_SYCL) \ 2953e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .TypeConstraint<T>("T") \ 2963e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .TypeConstraint<int32>("Tpaddings") \ 2973e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .HostMemory("paddings"), \ 298cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<SYCLDevice, T, int32>); \ 299cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("Pad") \ 300cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_SYCL) \ 301cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<T>("T") \ 302cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 303cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings"), \ 304cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<SYCLDevice, T, int64>); \ 305a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan REGISTER_KERNEL_BUILDER(Name("PadV2") \ 306a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .Device(DEVICE_SYCL) \ 307a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<T>("T") \ 308a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("Tpaddings") \ 309a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("paddings") \ 310a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("constant_values"), \ 311cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<SYCLDevice, T, int32>) \ 312cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang REGISTER_KERNEL_BUILDER(Name("PadV2") \ 313cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_SYCL) \ 314cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<T>("T") \ 315cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") \ 316cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") \ 317cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values"), \ 318cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<SYCLDevice, T, int64>) 3193e975ea978bac4d861bb09328b06f3c316212611Andrew Harp 3201b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan HseuTF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); 3213e975ea978bac4d861bb09328b06f3c316212611Andrew HarpREGISTER_KERNEL_BUILDER(Name("Pad") 3223e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .Device(DEVICE_SYCL) 3233e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .TypeConstraint<int32>("T") 3243e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .TypeConstraint<int32>("Tpaddings") 3253e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .HostMemory("input") 3263e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .HostMemory("paddings") 3273e975ea978bac4d861bb09328b06f3c316212611Andrew Harp .HostMemory("output"), 328cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int32>); 329cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("Pad") 330cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_SYCL) 331cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("T") 332cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") 333cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("input") 334cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") 335cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("output"), 336cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int64>); 337a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ RyanREGISTER_KERNEL_BUILDER(Name("PadV2") 338a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .Device(DEVICE_SYCL) 339a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("T") 340a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .TypeConstraint<int32>("Tpaddings") 341a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("input") 342a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("paddings") 343a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("constant_values") 344a6773e98e97956b7adf3aa51eb3548261f51d6f7RJ Ryan .HostMemory("output"), 345cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int32>); 346cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong TangREGISTER_KERNEL_BUILDER(Name("PadV2") 347cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .Device(DEVICE_SYCL) 348cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int32>("T") 349cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .TypeConstraint<int64>("Tpaddings") 350cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("input") 351cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("paddings") 352cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("constant_values") 353cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang .HostMemory("output"), 354cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang PadOp<CPUDevice, int32, int64>); 3551b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu#undef REGISTER_SYCL_KERNEL 356cbb705f10149a11b8d17182343ef12ab2dbfd7a8Yong Tang#endif // TENSORFLOW_USE_SYCL 3573e975ea978bac4d861bb09328b06f3c316212611Andrew Harp 358f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // end namespace tensorflow 359