// pad_op.cc revision 02dff6d0d838397860b6ff5256413b54da482996
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
// Enables Eigen's thread-pool support. Presumably this must be defined before
// any Eigen header is included so that Eigen::ThreadPoolDevice (aliased as
// CPUDevice below) is available — TODO confirm against Eigen docs.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/pad_op.h"

#include <memory>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

// Eigen device type used for the CPU instantiations of PadOp in this file.
typedef Eigen::ThreadPoolDevice CPUDevice;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice; 41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename T> 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass PadOp : public OpKernel { 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {} 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Compute(OpKernelContext* context) override { 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in0 = context->input(0); 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const Tensor& in1 = context->input(1); 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int dims = in0.dims(); 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMinDims = 0; 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMaxDims = 5; 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::Unimplemented("inputs rank not in [", kMinDims, ",", 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur kMaxDims, "]: ", dims)); 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2, 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("paddings must be a matrix with 2 columns: ", 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString())); 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath 
Kudlur const int fixed_dims = 6202dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1 6302dff6d0d838397860b6ff5256413b54da482996Josh Levenberg : dims; 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES( 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur context, fixed_dims == in1.dim_size(0), 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument( 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur "The first dimension of paddings must be the rank of inputs", 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in1.shape().DebugString(), " ", in0.shape().DebugString())); 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Compute the shape of the output tensor, and allocate it. 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TensorShape output_shape; 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>(); 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int d = 0; d < fixed_dims; ++d) { 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int32 before_d = paddings(d, 0); // Pad before existing elements. 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int32 after_d = paddings(d, 1); // Pad after exisitng elements. 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, before_d >= 0 && after_d >= 0, 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("Paddings must be non-negative: ", 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur before_d, " ", after_d)); 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int size_d = 8002dff6d0d838397860b6ff5256413b54da482996Josh Levenberg (allow_legacy_scalars() && d == in0.dims()) ? 
1 : in0.dim_size(d); 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur output_shape.AddDim(before_d + size_d + after_d); 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Tensor* output = nullptr; 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Invoke the dims-specific implementation. 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur switch (fixed_dims) { 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 0: 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<0>(context, in0.tensor<T, 0>(), paddings, output); 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 1: 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // TODO(irving): Once Pad doesn't need a scalar special case, 9302dff6d0d838397860b6ff5256413b54da482996Josh Levenberg // change flat to tensor. That is, once !allow_legacy_scalars(). 
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<1>(context, in0.flat<T>(), paddings, output); 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 2: 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<2>(context, in0.tensor<T, 2>(), paddings, output); 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 3: 100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<3>(context, in0.tensor<T, 3>(), paddings, output); 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 4: 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<4>(context, in0.tensor<T, 4>(), paddings, output); 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur case 5: 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Operate<5>(context, in0.tensor<T, 5>(), paddings, output); 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur break; 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur default: 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur OP_REQUIRES(context, false, 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur errors::InvalidArgument("Only ranks up to 5 supported: ", 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur in0.shape().DebugString())); 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <int Dims> 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Operate(OpKernelContext* context, 
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<T, Dims>::ConstTensor input, 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TTypes<int32>::ConstMatrix paddings, Tensor* output) { 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(Dims, paddings.dimension(0)); 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CHECK_EQ(2, paddings.dimension(1)); 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Eigen::array<std::pair<int32, int32>, Dims> paddings_array; 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for (int i = 0; i < Dims; ++i) { 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1)); 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur functor::Pad<Device, T, Dims> functor; 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input, 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur paddings_array); 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define REGISTER_KERNEL(type) \ 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur REGISTER_KERNEL_BUILDER(Name("Pad") \ 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .Device(DEVICE_CPU) \ 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .TypeConstraint<type>("T") \ 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur .HostMemory("paddings"), \ 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur PadOp<CPUDevice, type>) 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurTF_CALL_ALL_TYPES(REGISTER_KERNEL); 
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
// Declares the GPU specialization of Pad<GPUDevice, T, Dims>::operator() and
// suppresses implicit instantiation here; the definitions are expected to be
// provided by a separately compiled GPU translation unit.
#define DECLARE_GPU_SPEC(T, Dims)                                     \
  template <>                                                         \
  void Pad<GPUDevice, T, Dims>::operator()(                           \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,    \
      typename TTypes<T, Dims>::ConstTensor input,                    \
      Eigen::array<std::pair<int32, int32>, Dims> paddings);          \
  extern template struct Pad<GPUDevice, T, Dims>;

// One declaration per supported rank (0 through 5, matching PadOp::Compute).
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 0);    \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
// Keep the helper macros local to this section, as REGISTER_KERNEL is above.
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("Pad")                             \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .HostMemory("paddings"),            \
                          PadOp<GPUDevice, T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Registered for DEVICE_GPU but instantiated as PadOp<CPUDevice, int32>:
// "input", "paddings" and "output" are all pinned to host memory, so the
// CPUDevice specialization of PadOp is the one used for int32 here.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);

}  // end namespace tensorflow