extract_image_patches_op.cc revision 6b1b429da364b5f78adc1a3f3360c60dff2c2f41
1/* Copyright 2015 Google Inc. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// See docs in ../ops/image_ops.cc. 17 18#define USE_EIGEN_TENSOR 19#define EIGEN_USE_THREADS 20 21#include "tensorflow/core/kernels/extract_image_patches_op.h" 22#include <vector> 23#include "tensorflow/core/framework/numeric_op.h" 24#include "tensorflow/core/framework/op_kernel.h" 25#include "tensorflow/core/framework/register_types.h" 26#include "tensorflow/core/framework/tensor.h" 27#include "tensorflow/core/kernels/bounds_check.h" 28#include "tensorflow/core/kernels/ops_util.h" 29#include "tensorflow/core/lib/core/errors.h" 30#include "tensorflow/core/platform/logging.h" 31#include "tensorflow/core/platform/macros.h" 32#include "tensorflow/core/util/tensor_format.h" 33 34namespace tensorflow { 35 36typedef Eigen::ThreadPoolDevice CPUDevice; 37typedef Eigen::GpuDevice GPUDevice; 38 39static inline void ParseAttributeVec4(OpKernelConstruction* context, 40 const string& attr_name, 41 std::vector<int32>* attr) { 42 OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr)); 43 OP_REQUIRES( 44 context, (*attr)[0] == 1 && (*attr)[3] == 1, 45 errors::Unimplemented("Only support", attr_name, "across space.")); 46 OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1, 47 errors::OutOfRange(attr_name, "is out of range.")); 48} 49 50template <typename Device, typename T> 51class ExtractImagePatchesOp : public UnaryOp<T> { 52 public: 53 explicit ExtractImagePatchesOp(OpKernelConstruction* context) 54 : UnaryOp<T>(context) { 55 ParseAttributeVec4(context, "ksizes", &ksizes_); 56 ParseAttributeVec4(context, "strides", &strides_); 57 ParseAttributeVec4(context, "rates", &rates_); 58 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 59 } 60 61 void Compute(OpKernelContext* context) override { 62 // Input tensor is of the following dimensions: 63 // [ batch, in_rows, in_cols, channels ] 64 const Tensor& input = context->input(0); 65 OP_REQUIRES(context, input.dims() == 4, 66 errors::InvalidArgument("input must be 4-dimensional", 67 input.shape().DebugString())); 68 69 const int batch = input.dim_size(0); 70 const int in_rows = input.dim_size(1); 71 const int in_cols = input.dim_size(2); 72 const int depth = input.dim_size(3); 73 74 const int ksize_rows = ksizes_[1]; 75 const int ksize_cols = ksizes_[2]; 76 77 const int stride_rows = strides_[1]; 78 const int stride_cols = strides_[2]; 79 80 const int rate_rows = rates_[1]; 81 const int rate_cols = rates_[2]; 82 83 const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); 84 const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); 85 86 int out_rows = 0, out_cols = 0; 87 int pad_rows = 0, pad_cols = 0; 88 OP_REQUIRES_OK(context, Get2dOutputSize(in_rows, in_cols, ksize_rows_eff, 89 ksize_cols_eff, stride_rows, 90 stride_cols, padding_, &out_rows, 91 &out_cols, &pad_rows, &pad_cols)); 92 93 const std::vector<int64> out_sizes = {batch, out_rows, out_cols, 94 ksize_rows * ksize_cols * depth}; 95 TensorShape out_shape(out_sizes); 96 97 Tensor* output = nullptr; 98 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); 99 100 // If there is nothing to compute, return. 101 if (out_shape.num_elements() == 0) { 102 return; 103 } 104 105 functor::ExtractImagePatchesForward<Device, T>()( 106 context->eigen_device<Device>(), input.tensor<T, 4>(), ksize_rows, 107 ksize_cols, stride_rows, stride_cols, rate_rows, rate_cols, 108 BrainPadding2EigenPadding(padding_), output->tensor<T, 4>()); 109 } 110 111 private: 112 std::vector<int32> ksizes_; 113 std::vector<int32> strides_; 114 std::vector<int32> rates_; 115 116 Padding padding_; 117 118 TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); 119}; 120 121// Registration of the CPU implementations. 122#define REGISTER(T) \ 123 REGISTER_KERNEL_BUILDER( \ 124 Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 125 ExtractImagePatchesOp<CPUDevice, T>); 126 127TF_CALL_REAL_NUMBER_TYPES(REGISTER); 128 129#undef REGISTER 130 131#if GOOGLE_CUDA 132 133// Forward declarations of the functor specializations for GPU. 134namespace functor { 135 136#define DECLARE_GPU_SPEC(T) \ 137 template <> \ 138 void ExtractImagePatchesForward<GPUDevice, T>::operator()( \ 139 const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \ 140 int patch_rows, int patch_cols, int stride_rows, int stride_cols, \ 141 int rate_rows, int rate_cols, const Eigen::PaddingType& padding, \ 142 typename TTypes<T, 4>::Tensor output); \ 143 extern template struct ExtractImagePatchesForward<GPUDevice, T>; 144 145TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); 146 147#undef DECLARE_GPU_SPEC 148 149} // namespace functor 150 151// Registration of the GPU implementations. 152#define REGISTER(T) \ 153 REGISTER_KERNEL_BUILDER( \ 154 Name("ExtractImagePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 155 ExtractImagePatchesOp<GPUDevice, T>); 156 157TF_CALL_GPU_NUMBER_TYPES(REGISTER); 158 159#undef REGISTER 160 161#endif // GOOGLE_CUDA 162 163} // namespace tensorflow 164