extract_image_patches_op.cc revision 6b1b429da364b5f78adc1a3f3360c60dff2c2f41
1/* Copyright 2015 Google Inc. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/image_ops.cc.
17
18#define USE_EIGEN_TENSOR
19#define EIGEN_USE_THREADS
20
#include "tensorflow/core/kernels/extract_image_patches_op.h"
#include <limits>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
33
34namespace tensorflow {
35
36typedef Eigen::ThreadPoolDevice CPUDevice;
37typedef Eigen::GpuDevice GPUDevice;
38
39static inline void ParseAttributeVec4(OpKernelConstruction* context,
40                                      const string& attr_name,
41                                      std::vector<int32>* attr) {
42  OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
43  OP_REQUIRES(
44      context, (*attr)[0] == 1 && (*attr)[3] == 1,
45      errors::Unimplemented("Only support", attr_name, "across space."));
46  OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1,
47              errors::OutOfRange(attr_name, "is out of range."));
48}
49
50template <typename Device, typename T>
51class ExtractImagePatchesOp : public UnaryOp<T> {
52 public:
53  explicit ExtractImagePatchesOp(OpKernelConstruction* context)
54      : UnaryOp<T>(context) {
55    ParseAttributeVec4(context, "ksizes", &ksizes_);
56    ParseAttributeVec4(context, "strides", &strides_);
57    ParseAttributeVec4(context, "rates", &rates_);
58    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
59  }
60
61  void Compute(OpKernelContext* context) override {
62    // Input tensor is of the following dimensions:
63    // [ batch, in_rows, in_cols, channels ]
64    const Tensor& input = context->input(0);
65    OP_REQUIRES(context, input.dims() == 4,
66                errors::InvalidArgument("input must be 4-dimensional",
67                                        input.shape().DebugString()));
68
69    const int batch = input.dim_size(0);
70    const int in_rows = input.dim_size(1);
71    const int in_cols = input.dim_size(2);
72    const int depth = input.dim_size(3);
73
74    const int ksize_rows = ksizes_[1];
75    const int ksize_cols = ksizes_[2];
76
77    const int stride_rows = strides_[1];
78    const int stride_cols = strides_[2];
79
80    const int rate_rows = rates_[1];
81    const int rate_cols = rates_[2];
82
83    const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
84    const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
85
86    int out_rows = 0, out_cols = 0;
87    int pad_rows = 0, pad_cols = 0;
88    OP_REQUIRES_OK(context, Get2dOutputSize(in_rows, in_cols, ksize_rows_eff,
89                                            ksize_cols_eff, stride_rows,
90                                            stride_cols, padding_, &out_rows,
91                                            &out_cols, &pad_rows, &pad_cols));
92
93    const std::vector<int64> out_sizes = {batch, out_rows, out_cols,
94                                          ksize_rows * ksize_cols * depth};
95    TensorShape out_shape(out_sizes);
96
97    Tensor* output = nullptr;
98    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
99
100    // If there is nothing to compute, return.
101    if (out_shape.num_elements() == 0) {
102      return;
103    }
104
105    functor::ExtractImagePatchesForward<Device, T>()(
106        context->eigen_device<Device>(), input.tensor<T, 4>(), ksize_rows,
107        ksize_cols, stride_rows, stride_cols, rate_rows, rate_cols,
108        BrainPadding2EigenPadding(padding_), output->tensor<T, 4>());
109  }
110
111 private:
112  std::vector<int32> ksizes_;
113  std::vector<int32> strides_;
114  std::vector<int32> rates_;
115
116  Padding padding_;
117
118  TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
119};
120
121// Registration of the CPU implementations.
122#define REGISTER(T)                                                          \
123  REGISTER_KERNEL_BUILDER(                                                   \
124      Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
125      ExtractImagePatchesOp<CPUDevice, T>);
126
127TF_CALL_REAL_NUMBER_TYPES(REGISTER);
128
129#undef REGISTER
130
#if GOOGLE_CUDA

namespace functor {

// For each type in TF_CALL_GPU_NUMBER_TYPES, declare (without defining) the
// GPU specialization of ExtractImagePatchesForward::operator(). The
// `extern template` line is an explicit instantiation declaration. It keeps
// this translation unit from instantiating the GPU functor itself.
// NOTE(review): the definitions presumably live in a separately compiled
// CUDA source (e.g. extract_image_patches_op_gpu.cu.cc) -- confirm.
#define DECLARE_EXTRACT_PATCHES_GPU_SPEC(T)                                  \
  template <>                                                                \
  void ExtractImagePatchesForward<GPUDevice, T>::operator()(                 \
      const GPUDevice& device, typename TTypes<T, 4>::ConstTensor images,    \
      int ksize_rows, int ksize_cols, int stride_rows, int stride_cols,      \
      int rate_rows, int rate_cols, const Eigen::PaddingType& padding_type,  \
      typename TTypes<T, 4>::Tensor patches);                                \
  extern template struct ExtractImagePatchesForward<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_EXTRACT_PATCHES_GPU_SPEC);

#undef DECLARE_EXTRACT_PATCHES_GPU_SPEC

}  // namespace functor

// GPU kernels: one ExtractImagePatchesOp<GPUDevice, T> registration for every
// element type T enumerated by TF_CALL_GPU_NUMBER_TYPES.
#define EXTRACT_IMAGE_PATCHES_REGISTER_GPU(T)            \
  REGISTER_KERNEL_BUILDER(Name("ExtractImagePatches")    \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<T>("T"),   \
                          ExtractImagePatchesOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(EXTRACT_IMAGE_PATCHES_REGISTER_GPU);

#undef EXTRACT_IMAGE_PATCHES_REGISTER_GPU

#endif  // GOOGLE_CUDA
161#endif  // GOOGLE_CUDA
162
163}  // namespace tensorflow
164