1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_KERNELS_CONV_2D_H_ 17#define TENSORFLOW_KERNELS_CONV_2D_H_ 18 19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20#include "tensorflow/core/framework/tensor_types.h" 21#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h" 22#include "tensorflow/core/kernels/eigen_spatial_convolutions.h" 23#include "tensorflow/core/util/tensor_format.h" 24 25namespace tensorflow { 26namespace functor { 27 28// TODO(yangke): revisit these operations and in particular, see if we can 29// combine all of them into just one operation without causing nvcc to 30// timeout. 31template <typename Device, typename T, int Dims, typename IndexType> 32struct ShuffleAndReverse { 33 void operator()(const Device& d, 34 typename TTypes<T, Dims, IndexType>::ConstTensor input, 35 const Eigen::DSizes<IndexType, Dims>& order, 36 const Eigen::array<bool, Dims>& reverse_dims, 37 typename TTypes<T, Dims, IndexType>::Tensor output) { 38 output.device(d) = input.shuffle(order).reverse(reverse_dims); 39 } 40}; 41 42template <typename Device, typename T, int Dims, typename IndexType> 43struct InflatePadAndShuffle { 44 void operator()( 45 const Device& d, typename TTypes<T, Dims, IndexType>::ConstTensor input, 46 const Eigen::DSizes<IndexType, Dims>& strides, 47 const Eigen::array<Eigen::IndexPair<IndexType>, Dims>& pad_dims, 48 const Eigen::DSizes<IndexType, Dims>& order, 49 typename TTypes<T, Dims, IndexType>::Tensor output) { 50 output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order); 51 } 52}; 53 54template <typename Device, typename Input, typename Filter, typename Output> 55void SpatialConvolutionFunc(const Device& d, Output output, Input input, 56 Filter filter, int row_stride, int col_stride, 57 int row_dilation, int col_dilation, 58 const Eigen::PaddingType& padding) { 59 // Need to swap row/col when calling Eigen. 60 output.device(d) = 61 Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding, 62 col_dilation, row_dilation); 63} 64 65template <typename Device, typename T> 66struct SpatialConvolution { 67 void operator()(const Device& d, typename TTypes<T, 4>::Tensor output, 68 typename TTypes<T, 4>::ConstTensor input, 69 typename TTypes<T, 4>::ConstTensor filter, int row_stride, 70 int col_stride, int row_dilation, int col_dilation, 71 const Eigen::PaddingType& padding) { 72 SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, 73 row_dilation, col_dilation, padding); 74 } 75}; 76 77template <typename Device> 78struct SpatialConvolution<Device, Eigen::half> { 79 void operator()(const Device& d, 80 typename TTypes<Eigen::half, 4>::Tensor output, 81 typename TTypes<Eigen::half, 4>::ConstTensor input, 82 typename TTypes<Eigen::half, 4>::ConstTensor filter, 83 int row_stride, int col_stride, int row_dilation, 84 int col_dilation, const Eigen::PaddingType& padding) { 85 output.device(d) = 86 Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(), 87 col_stride, row_stride, padding, col_dilation, 88 row_dilation) 89 .cast<Eigen::half>(); 90 } 91}; 92 93template <typename Device, typename T> 94struct SpatialConvolutionBackwardInput { 95 void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward, 96 typename TTypes<T, 4>::ConstTensor kernel, 97 typename TTypes<T, 4>::ConstTensor output_backward, 98 int row_stride, int col_stride, int row_dilation, 99 int col_dilation) { 100 // Need to swap row/col when calling Eigen. 101 input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput( 102 kernel, output_backward, input_backward.dimension(2), 103 input_backward.dimension(1), col_stride, row_stride, col_dilation, 104 row_dilation); 105 } 106}; 107 108template <typename Device, typename T> 109struct SpatialConvolutionBackwardFilter { 110 void operator()(const Device& d, 111 typename TTypes<T, 4>::Tensor kernel_backward, 112 typename TTypes<T, 4>::ConstTensor input, 113 typename TTypes<T, 4>::ConstTensor output_backward, 114 int row_stride, int col_stride, int row_dilation, 115 int col_dilation) { 116 // Need to swap row/col when calling Eigen. 117 kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel( 118 input, output_backward, kernel_backward.dimension(1), 119 kernel_backward.dimension(0), col_stride, row_stride, col_dilation, 120 row_dilation); 121 } 122}; 123 124// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h. 125// My initial attempt to do this compiled but failed in the pytest 126// due to a swigdeps error. 127template <typename Device, typename T> 128struct MatMulConvFunctor { 129 // Computes on device "d": out = in0 * in1, where * is matrix 130 // multiplication. 131 void operator()( 132 const Device& d, typename TTypes<T, 2>::Tensor out, 133 typename TTypes<T, 2>::ConstTensor in0, 134 typename TTypes<T, 2>::ConstTensor in1, 135 const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) { 136 out.device(d) = in0.contract(in1, dim_pair); 137 } 138}; 139 140// Shuffles a filter tensor from: 141// [<spatial_dims>, in, out] 142// to: 143// [out, in, <spatial_dims>] 144template <typename Device, typename T, typename IndexType, int NDIMS> 145struct TransformFilter { 146 void operator()(const Device& d, 147 typename TTypes<T, NDIMS, IndexType>::ConstTensor in, 148 typename TTypes<T, NDIMS, IndexType>::Tensor out) { 149 // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together 150 // to speed up the shuffle operation. 151 Eigen::DSizes<IndexType, 3> merged_dims; 152 merged_dims[0] = in.dimension(0); // spatial dimensions 153 for (int i = 1; i < NDIMS - 2; ++i) { 154 merged_dims[0] *= in.dimension(i); 155 } 156 merged_dims[1] = in.dimension(NDIMS - 2); // input filters 157 merged_dims[2] = in.dimension(NDIMS - 1); // output filters 158 159 Eigen::DSizes<IndexType, NDIMS> expanded_dims; 160 expanded_dims[0] = in.dimension(NDIMS - 1); // output filters 161 expanded_dims[1] = in.dimension(NDIMS - 2); // input filters 162 for (int i = 0; i < NDIMS; ++i) { // spatial dimensions 163 expanded_dims[i + 2] = in.dimension(i); 164 } 165 166 out.device(d) = in.reshape(merged_dims) 167 .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0)) 168 .reshape(expanded_dims); 169 } 170}; 171 172template <typename Device, typename T, typename IndexType> 173struct TransformDepth { 174 void operator()(const Device& d, 175 typename TTypes<T, 4, IndexType>::ConstTensor in, 176 const Eigen::DSizes<IndexType, 4>& shuffle, 177 typename TTypes<T, 4, IndexType>::Tensor out) { 178 Eigen::DSizes<IndexType, 3> merged_dims; 179 Eigen::DSizes<IndexType, 4> expanded_dims; 180 Eigen::DSizes<IndexType, 3> new_shuffle; 181 182 // Merge dimensions that won't be shuffled together to speed things up. 183 if (shuffle[1] == 2 && shuffle[2] == 3) { 184 merged_dims[0] = in.dimension(0); 185 merged_dims[1] = in.dimension(1); 186 merged_dims[2] = in.dimension(2) * in.dimension(3); 187 new_shuffle[0] = shuffle[0]; 188 new_shuffle[1] = 2; 189 new_shuffle[2] = shuffle[3]; 190 expanded_dims[0] = in.dimension(shuffle[0]); 191 expanded_dims[1] = in.dimension(2); 192 expanded_dims[2] = in.dimension(3); 193 expanded_dims[3] = in.dimension(shuffle[3]); 194 } else if (shuffle[0] == 2 && shuffle[1] == 3) { 195 merged_dims[0] = in.dimension(0); 196 merged_dims[1] = in.dimension(1); 197 merged_dims[2] = in.dimension(2) * in.dimension(3); 198 new_shuffle[0] = 2; 199 new_shuffle[1] = shuffle[2]; 200 new_shuffle[2] = shuffle[3]; 201 expanded_dims[0] = in.dimension(2); 202 expanded_dims[1] = in.dimension(3); 203 expanded_dims[2] = in.dimension(shuffle[2]); 204 expanded_dims[3] = in.dimension(shuffle[3]); 205 } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 && 206 shuffle[3] == 2) { 207 merged_dims[0] = in.dimension(0); 208 merged_dims[1] = in.dimension(1) * in.dimension(2); 209 merged_dims[2] = in.dimension(3); 210 new_shuffle[0] = 0; 211 new_shuffle[1] = 2; 212 new_shuffle[2] = 1; 213 expanded_dims[0] = in.dimension(0); 214 expanded_dims[1] = in.dimension(3); 215 expanded_dims[2] = in.dimension(1); 216 expanded_dims[3] = in.dimension(2); 217 } else { 218 assert(false && "unexpected shuffle"); 219 } 220 221 out.device(d) = 222 in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims); 223 } 224}; 225 226template <typename Device, typename T, typename IndexType, int NDIMS> 227struct PadInput { 228 void operator()(const Device& d, 229 typename TTypes<T, NDIMS, IndexType>::ConstTensor in, 230 const std::array<int, NDIMS - 2>& padding_left, 231 const std::array<int, NDIMS - 2>& padding_right, 232 typename TTypes<T, NDIMS, IndexType>::Tensor out, 233 TensorFormat format) { 234 Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding; 235 padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0}; 236 for (int i = 0; i < NDIMS - 2; ++i) { 237 padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] = { 238 padding_left[i], padding_right[i]}; 239 } 240 padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0}; 241 out.device(d) = in.pad(padding); 242 } 243}; 244 245// Converts a tensor from: 246// [batch, <spatial>, filters] 247// to: 248// [batch, filters, <spatial>] 249template <typename Device, typename T, int NDIMS> 250struct NHWCToNCHW { 251 void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in, 252 typename TTypes<T, NDIMS>::Tensor out); 253}; 254 255// Converts a tensor from: 256// [batch, filters, <spatial>] 257// to: 258// [batch, <spatial>, filters] 259template <typename Device, typename T, int NDIMS> 260struct NCHWToNHWC { 261 void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in, 262 typename TTypes<T, NDIMS>::Tensor out); 263}; 264 265// Converts a tensor from: 266// [dim0, dim1, dim2] 267// to: 268// [dim0, dim2, dim1] 269template <typename Device, typename T, bool conjugate = false> 270struct SwapDimension1And2InTensor3 { 271 void operator()(const Device& d, const T* in, 272 const gtl::ArraySlice<int64>& input_dims, T* out); 273}; 274 275// Converts a tensor from: 276// [dim0, dim1, dim2] 277// to: 278// [dim2, dim1, dim0] 279template <typename Device, typename T, bool conjugate = false> 280struct SwapDimension0And2InTensor3 { 281 void operator()(const Device& d, const T* in, 282 const gtl::ArraySlice<int64>& input_dims, T* out); 283}; 284 285// Reverses the effect of TransformFilter above. 286template <typename Device, typename T, int NDIMS> 287struct ReverseTransformFilter { 288 void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in, 289 typename TTypes<T, NDIMS>::Tensor out); 290}; 291 292} // namespace functor 293 294template <class T> 295class ConvAlgorithmMap; 296 297template <> 298class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {}; 299} // namespace tensorflow 300 301#endif // TENSORFLOW_KERNELS_CONV_2D_H_ 302