/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/kernels/pad_op.h" 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory> 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <string> 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <utility> 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 2656313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op.h" 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h" 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/register_types.h" 303ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor.h" 313ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/framework/tensor_shape.h" 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_types.h" 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h" 3456313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/platform/logging.h" 353ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice; 
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice; 413e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL 423e975ea978bac4d861bb09328b06f3c316212611Andrew Harptypedef Eigen::SyclDevice SYCLDevice; 433e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#endif // TENSORFLOW_USE_SYCL 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename T> 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass PadOp : public OpKernel { 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {} 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Compute(OpKernelContext* context) override { 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in0 = context->input(0); 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in1 = context->input(1); 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int dims = in0.dims(); 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMinDims = 0; 5513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan static const int kMaxDims = 6; 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::Unimplemented("inputs rank not in [", kMinDims, ",", 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur kMaxDims, "]: ", dims)); 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2, 
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("paddings must be a matrix with 2 columns: ", 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString())); 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int fixed_dims = 6502dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1 6602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg : dims; 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, fixed_dims == in1.dim_size(0), 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument( 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur "The first dimension of paddings must be the rank of inputs", 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString(), " ", in0.shape().DebugString())); 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Compute the shape of the output tensor, and allocate it. 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShape output_shape; 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>(); 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int d = 0; d < fixed_dims; ++d) { 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int32 before_d = paddings(d, 0); // Pad before existing elements. 7859f1eba5fb94506a205fa2e81145667754739da5Martin Wicke const int32 after_d = paddings(d, 1); // Pad after existing elements. 
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, before_d >= 0 && after_d >= 0, 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("Paddings must be non-negative: ", 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur before_d, " ", after_d)); 820307b9d8569011a94535322899d375a79c49df80David G. Andersen const int64 size_d = 8302dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d); 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur output_shape.AddDim(before_d + size_d + after_d); 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 86a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower 87a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower // If there is no padding to be done, forward the input to output. 88a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower if (output_shape.num_elements() == in0.NumElements()) { 89a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower // When num_elements == 0, shape may have changed. 90a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower Tensor out; 91a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower CHECK(out.CopyFrom(in0, output_shape)); 92a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower context->set_output(0, out); 93a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower return; 94a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower } 95a1f85a49b77ae8bcf696caf872d09d1649658512A. Unique TensorFlower 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* output = nullptr; 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Invoke the dims-specific implementation. 
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur switch (fixed_dims) { 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 0: 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<0>(context, in0.tensor<T, 0>(), paddings, output); 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 1: 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // TODO(irving): Once Pad doesn't need a scalar special case, 10602dff6d0d838397860b6ff5256413b54da482996Josh Levenberg // change flat to tensor. That is, once !allow_legacy_scalars(). 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<1>(context, in0.flat<T>(), paddings, output); 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 2: 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<2>(context, in0.tensor<T, 2>(), paddings, output); 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 3: 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<3>(context, in0.tensor<T, 3>(), paddings, output); 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 4: 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<4>(context, in0.tensor<T, 4>(), paddings, output); 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 5: 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<5>(context, in0.tensor<T, 5>(), paddings, output); 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 12113b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan case 6: 12213b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan Operate<6>(context, in0.tensor<T, 6>(), 
paddings, output); 12313b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan break; 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur default: 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, false, 12613b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan errors::InvalidArgument("Only ranks up to 6 supported: ", 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in0.shape().DebugString())); 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <int Dims> 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Operate(OpKernelContext* context, 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<T, Dims>::ConstTensor input, 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TTypes<int32>::ConstMatrix paddings, Tensor* output) { 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(Dims, paddings.dimension(0)); 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(2, paddings.dimension(1)); 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Eigen::array<std::pair<int32, int32>, Dims> paddings_array; 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int i = 0; i < Dims; ++i) { 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1)); 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur functor::Pad<Device, T, Dims> functor; 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input, 144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur paddings_array); 
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define REGISTER_KERNEL(type) \ 149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur REGISTER_KERNEL_BUILDER(Name("Pad") \ 150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .Device(DEVICE_CPU) \ 151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .TypeConstraint<type>("T") \ 152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .HostMemory("paddings"), \ 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur PadOp<CPUDevice, type>) 154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1557f12947e4f31cdf9a0cca291a653980fa204d686Benoit SteinerTF_CALL_POD_TYPES(REGISTER_KERNEL); 156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#undef REGISTER_KERNEL 157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#if GOOGLE_CUDA 159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Forward declarations of the functor specializations for GPU. 
160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace functor { 161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define DECLARE_GPU_SPEC(T, Dims) \ 162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <> \ 163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Pad<GPUDevice, T, Dims>::operator()( \ 164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \ 165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<T, Dims>::ConstTensor input, \ 166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Eigen::array<std::pair<int32, int32>, Dims> paddings); \ 167f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur extern template struct Pad<GPUDevice, T, Dims>; 168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 169f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define DECLARE_GPU_SPECS(T) \ 170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 0); \ 171f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 1); \ 172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 2); \ 173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 3); \ 174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DECLARE_GPU_SPEC(T, 4); \ 17513b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan DECLARE_GPU_SPEC(T, 5); \ 17613b63bd87e53fdb01cc87d3030f79c73bd487aa0Vijay Vasudevan DECLARE_GPU_SPEC(T, 6); 177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); 179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace functor 180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Registration of the GPU implementations. 
#define REGISTER_GPU_KERNEL(T)                                  \
  REGISTER_KERNEL_BUILDER(Name("Pad")                           \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<T>("T")           \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),          \
                          PadOp<GPUDevice, T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note it runs the CPU device implementation on host-memory tensors.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
#endif

#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SYCL_KERNEL(T)                                 \
  REGISTER_KERNEL_BUILDER(Name("Pad")                           \
                              .Device(DEVICE_SYCL)              \
                              .TypeConstraint<T>("T")           \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),          \
                          PadOp<SYCLDevice, T>)

REGISTER_SYCL_KERNEL(float);
REGISTER_SYCL_KERNEL(double);

// A special SYCL-device kernel for int32, analogous to the GPU one above.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory,
// and runs the CPU device implementation on them.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow