pad_op.cc revision 02dff6d0d838397860b6ff5256413b54da482996
19c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur/* Copyright 2015 Google Inc. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
// See docs in ../ops/array_ops.cc.
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/kernels/pad_op.h"
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <memory>
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <string>
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <utility>
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
2656313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op.h"
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h"
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/register_types.h"
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_types.h"
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/types.h"
3256313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/platform/logging.h"
3356313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/platform/port.h"
34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/public/tensor.h"
3556313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "tensorflow/core/public/tensor_shape.h"
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtypedef Eigen::GpuDevice GPUDevice;
41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
template <typename Device, typename T>
class PadOp : public OpKernel {
 public:
  explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}

  // Pads input 0 according to input 1 ("paddings"), a [rank(input), 2]
  // int32 matrix whose d-th row gives the number of values to insert
  // before and after the contents of dimension d.
  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);  // Tensor to pad.
    const Tensor& in1 = context->input(1);  // Paddings matrix.
    const int dims = in0.dims();
    // Operate<Dims> below is only instantiated for ranks 0 through 5.
    static const int kMinDims = 0;
    static const int kMaxDims = 5;
    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
                                      kMaxDims, "]: ", dims));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                in1.shape().DebugString()));
    // Legacy graphs may pad a scalar with a 1x2 paddings matrix; treat the
    // scalar as rank 1 in that case so the shape check below lines up.
    const int fixed_dims =
        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
                                                                      : dims;
    OP_REQUIRES(
        context, fixed_dims == in1.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs",
            in1.shape().DebugString(), " ", in0.shape().DebugString()));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape output_shape;
    TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
    for (int d = 0; d < fixed_dims; ++d) {
      const int32 before_d = paddings(d, 0);  // Pad before existing elements.
      const int32 after_d = paddings(d, 1);   // Pad after existing elements.
      OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
                  errors::InvalidArgument("Paddings must be non-negative: ",
                                          before_d, " ", after_d));
      // In the legacy-scalar case d can equal in0.dims(); the scalar then
      // contributes a single element in that synthetic dimension.
      const int size_d =
          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
      output_shape.AddDim(before_d + size_d + after_d);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // Invoke the dims-specific implementation.
    switch (fixed_dims) {
      case 0:
        Operate<0>(context, in0.tensor<T, 0>(), paddings, output);
        break;
      case 1:
        // TODO(irving): Once Pad doesn't need a scalar special case,
        // change flat to tensor.  That is, once !allow_legacy_scalars().
        Operate<1>(context, in0.flat<T>(), paddings, output);
        break;
      case 2:
        Operate<2>(context, in0.tensor<T, 2>(), paddings, output);
        break;
      case 3:
        Operate<3>(context, in0.tensor<T, 3>(), paddings, output);
        break;
      case 4:
        Operate<4>(context, in0.tensor<T, 4>(), paddings, output);
        break;
      case 5:
        Operate<5>(context, in0.tensor<T, 5>(), paddings, output);
        break;
      default:
        // Unreachable given the rank check above; kept as a defensive error.
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 5 supported: ",
                                            in0.shape().DebugString()));
    }
  }

 private:
  // Rank-specific helper: converts the paddings matrix into the
  // Eigen::array form the device functor expects, then runs the functor
  // on this kernel's device to fill `output`.
  template <int Dims>
  void Operate(OpKernelContext* context,
               typename TTypes<T, Dims>::ConstTensor input,
               TTypes<int32>::ConstMatrix paddings, Tensor* output) {
    CHECK_EQ(Dims, paddings.dimension(0));
    CHECK_EQ(2, paddings.dimension(1));
    Eigen::array<std::pair<int32, int32>, Dims> paddings_array;
    for (int i = 0; i < Dims; ++i) {
      paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1));
    }
    functor::Pad<Device, T, Dims> functor;
    functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
            paddings_array);
  }
};
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// CPU kernel registrations for every type TF_CALL_ALL_TYPES covers.
// "paddings" is pinned to host memory because Compute reads its values
// directly to determine the output shape.
#define REGISTER_KERNEL(type)                            \
  REGISTER_KERNEL_BUILDER(Name("Pad")                    \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("paddings"),   \
                          PadOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// The definitions live in the CUDA compilation unit (pad_op_gpu.cu.cc);
// "extern template" suppresses implicit instantiation here.
namespace functor {
#define DECLARE_GPU_SPEC(T, Dims)                                  \
  template <>                                                      \
  void Pad<GPUDevice, T, Dims>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \
      typename TTypes<T, Dims>::ConstTensor input,                 \
      Eigen::array<std::pair<int32, int32>, Dims> paddings);       \
  extern template struct Pad<GPUDevice, T, Dims>;

// One declaration per rank handled by PadOp::Compute (0 through 5).
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 0);    \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
}  // namespace functor

// Registration of the GPU implementations.  As on CPU, "paddings" stays
// in host memory since Compute reads it to build the output shape.
#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#endif  // GOOGLE_CUDA
174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note that although registered under DEVICE_GPU, the kernel instantiated
// is PadOp<CPUDevice, int32>: with input, paddings, and output all pinned
// to host memory, the computation runs on the CPU.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
185ab34d55ce7618e52069a2e1c9e51aac5a1ea81c3Vijay Vasudevan
186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // end namespace tensorflow
187