1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_KERNELS_CONV_2D_H_
17#define TENSORFLOW_KERNELS_CONV_2D_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/framework/tensor_types.h"
21#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
22#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
23#include "tensorflow/core/util/tensor_format.h"
24
25namespace tensorflow {
26namespace functor {
27
28// TODO(yangke): revisit these operations and in particular, see if we can
29// combine all of them into just one operation without causing nvcc to
30// timeout.
31template <typename Device, typename T, int Dims, typename IndexType>
32struct ShuffleAndReverse {
33  void operator()(const Device& d,
34                  typename TTypes<T, Dims, IndexType>::ConstTensor input,
35                  const Eigen::DSizes<IndexType, Dims>& order,
36                  const Eigen::array<bool, Dims>& reverse_dims,
37                  typename TTypes<T, Dims, IndexType>::Tensor output) {
38    output.device(d) = input.shuffle(order).reverse(reverse_dims);
39  }
40};
41
42template <typename Device, typename T, int Dims, typename IndexType>
43struct InflatePadAndShuffle {
44  void operator()(
45      const Device& d, typename TTypes<T, Dims, IndexType>::ConstTensor input,
46      const Eigen::DSizes<IndexType, Dims>& strides,
47      const Eigen::array<Eigen::IndexPair<IndexType>, Dims>& pad_dims,
48      const Eigen::DSizes<IndexType, Dims>& order,
49      typename TTypes<T, Dims, IndexType>::Tensor output) {
50    output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order);
51  }
52};
53
54template <typename Device, typename Input, typename Filter, typename Output>
55void SpatialConvolutionFunc(const Device& d, Output output, Input input,
56                            Filter filter, int row_stride, int col_stride,
57                            int row_dilation, int col_dilation,
58                            const Eigen::PaddingType& padding) {
59  // Need to swap row/col when calling Eigen.
60  output.device(d) =
61      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
62                                col_dilation, row_dilation);
63}
64
65template <typename Device, typename T>
66struct SpatialConvolution {
67  void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
68                  typename TTypes<T, 4>::ConstTensor input,
69                  typename TTypes<T, 4>::ConstTensor filter, int row_stride,
70                  int col_stride, int row_dilation, int col_dilation,
71                  const Eigen::PaddingType& padding) {
72    SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
73                           row_dilation, col_dilation, padding);
74  }
75};
76
77template <typename Device>
78struct SpatialConvolution<Device, Eigen::half> {
79  void operator()(const Device& d,
80                  typename TTypes<Eigen::half, 4>::Tensor output,
81                  typename TTypes<Eigen::half, 4>::ConstTensor input,
82                  typename TTypes<Eigen::half, 4>::ConstTensor filter,
83                  int row_stride, int col_stride, int row_dilation,
84                  int col_dilation, const Eigen::PaddingType& padding) {
85    output.device(d) =
86        Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
87                                  col_stride, row_stride, padding, col_dilation,
88                                  row_dilation)
89            .cast<Eigen::half>();
90  }
91};
92
93template <typename Device, typename T>
94struct SpatialConvolutionBackwardInput {
95  void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
96                  typename TTypes<T, 4>::ConstTensor kernel,
97                  typename TTypes<T, 4>::ConstTensor output_backward,
98                  int row_stride, int col_stride, int row_dilation,
99                  int col_dilation) {
100    // Need to swap row/col when calling Eigen.
101    input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
102        kernel, output_backward, input_backward.dimension(2),
103        input_backward.dimension(1), col_stride, row_stride, col_dilation,
104        row_dilation);
105  }
106};
107
108template <typename Device, typename T>
109struct SpatialConvolutionBackwardFilter {
110  void operator()(const Device& d,
111                  typename TTypes<T, 4>::Tensor kernel_backward,
112                  typename TTypes<T, 4>::ConstTensor input,
113                  typename TTypes<T, 4>::ConstTensor output_backward,
114                  int row_stride, int col_stride, int row_dilation,
115                  int col_dilation) {
116    // Need to swap row/col when calling Eigen.
117    kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
118        input, output_backward, kernel_backward.dimension(1),
119        kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
120        row_dilation);
121  }
122};
123
124// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h.
125// My initial attempt to do this compiled but failed in the pytest
126// due to a swigdeps error.
127template <typename Device, typename T>
128struct MatMulConvFunctor {
129  // Computes on device "d": out = in0 * in1, where * is matrix
130  // multiplication.
131  void operator()(
132      const Device& d, typename TTypes<T, 2>::Tensor out,
133      typename TTypes<T, 2>::ConstTensor in0,
134      typename TTypes<T, 2>::ConstTensor in1,
135      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
136    out.device(d) = in0.contract(in1, dim_pair);
137  }
138};
139
140// Shuffles a filter tensor from:
141//   [<spatial_dims>, in, out]
142// to:
143//   [out, in, <spatial_dims>]
144template <typename Device, typename T, typename IndexType, int NDIMS>
145struct TransformFilter {
146  void operator()(const Device& d,
147                  typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
148                  typename TTypes<T, NDIMS, IndexType>::Tensor out) {
149    // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
150    // to speed up the shuffle operation.
151    Eigen::DSizes<IndexType, 3> merged_dims;
152    merged_dims[0] = in.dimension(0);  // spatial dimensions
153    for (int i = 1; i < NDIMS - 2; ++i) {
154      merged_dims[0] *= in.dimension(i);
155    }
156    merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
157    merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
158
159    Eigen::DSizes<IndexType, NDIMS> expanded_dims;
160    expanded_dims[0] = in.dimension(NDIMS - 1);  // output filters
161    expanded_dims[1] = in.dimension(NDIMS - 2);  // input filters
162    for (int i = 0; i < NDIMS; ++i) {            // spatial dimensions
163      expanded_dims[i + 2] = in.dimension(i);
164    }
165
166    out.device(d) = in.reshape(merged_dims)
167                        .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
168                        .reshape(expanded_dims);
169  }
170};
171
// Permutes the four axes of `in` according to `shuffle` and writes the result
// to `out` on device `d`. Output axis k has the extent of input axis
// shuffle[k].
//
// Only the three shuffle patterns handled below are supported. In each of
// them, two input axes stay adjacent and in their original order. Those two
// axes are merged into one, the permutation is performed as a rank-3 shuffle,
// and the result is reshaped back to rank 4:
//   * shuffle[1] == 2 && shuffle[2] == 3: input axes 2, 3 become output
//     axes 1, 2.
//   * shuffle[0] == 2 && shuffle[1] == 3: input axes 2, 3 become output
//     axes 0, 1.
//   * shuffle == (0, 3, 1, 2): input axes 1, 2 become output axes 2, 3.
// NOTE(review): any other pattern only trips the assert. In NDEBUG builds the
// assert is compiled out, and the default-constructed DSizes are used
// unchanged. Confirm that callers never pass other patterns.
template <typename Device, typename T, typename IndexType>
struct TransformDepth {
  void operator()(const Device& d,
                  typename TTypes<T, 4, IndexType>::ConstTensor in,
                  const Eigen::DSizes<IndexType, 4>& shuffle,
                  typename TTypes<T, 4, IndexType>::Tensor out) {
    Eigen::DSizes<IndexType, 3> merged_dims;    // rank-3 view of `in`
    Eigen::DSizes<IndexType, 4> expanded_dims;  // final rank-4 output shape
    Eigen::DSizes<IndexType, 3> new_shuffle;    // permutation of merged view

    // Merge dimensions that won't be shuffled together to speed things up.
    if (shuffle[1] == 2 && shuffle[2] == 3) {
      // Input axes 2 and 3 are merged (merged axis 2). shuffle[0] and
      // shuffle[3] are then in {0, 1}, so they index the merged view directly.
      merged_dims[0] = in.dimension(0);
      merged_dims[1] = in.dimension(1);
      merged_dims[2] = in.dimension(2) * in.dimension(3);
      new_shuffle[0] = shuffle[0];
      new_shuffle[1] = 2;
      new_shuffle[2] = shuffle[3];
      expanded_dims[0] = in.dimension(shuffle[0]);
      expanded_dims[1] = in.dimension(2);
      expanded_dims[2] = in.dimension(3);
      expanded_dims[3] = in.dimension(shuffle[3]);
    } else if (shuffle[0] == 2 && shuffle[1] == 3) {
      // Input axes 2 and 3 are merged and moved to the front. shuffle[2] and
      // shuffle[3] are then in {0, 1}.
      merged_dims[0] = in.dimension(0);
      merged_dims[1] = in.dimension(1);
      merged_dims[2] = in.dimension(2) * in.dimension(3);
      new_shuffle[0] = 2;
      new_shuffle[1] = shuffle[2];
      new_shuffle[2] = shuffle[3];
      expanded_dims[0] = in.dimension(2);
      expanded_dims[1] = in.dimension(3);
      expanded_dims[2] = in.dimension(shuffle[2]);
      expanded_dims[3] = in.dimension(shuffle[3]);
    } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
               shuffle[3] == 2) {
      // Input axes 1 and 2 are merged (merged axis 1) and swapped with axis 3.
      merged_dims[0] = in.dimension(0);
      merged_dims[1] = in.dimension(1) * in.dimension(2);
      merged_dims[2] = in.dimension(3);
      new_shuffle[0] = 0;
      new_shuffle[1] = 2;
      new_shuffle[2] = 1;
      expanded_dims[0] = in.dimension(0);
      expanded_dims[1] = in.dimension(3);
      expanded_dims[2] = in.dimension(1);
      expanded_dims[3] = in.dimension(2);
    } else {
      assert(false && "unexpected shuffle");
    }

    // Shuffle the rank-3 view, then split the merged axis back apart.
    out.device(d) =
        in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
  }
};
225
226template <typename Device, typename T, typename IndexType, int NDIMS>
227struct PadInput {
228  void operator()(const Device& d,
229                  typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
230                  const std::array<int, NDIMS - 2>& padding_left,
231                  const std::array<int, NDIMS - 2>& padding_right,
232                  typename TTypes<T, NDIMS, IndexType>::Tensor out,
233                  TensorFormat format) {
234    Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding;
235    padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0};
236    for (int i = 0; i < NDIMS - 2; ++i) {
237      padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] = {
238          padding_left[i], padding_right[i]};
239    }
240    padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0};
241    out.device(d) = in.pad(padding);
242  }
243};
244
245// Converts a tensor from:
246//   [batch, <spatial>, filters]
247// to:
248//   [batch, filters, <spatial>]
249template <typename Device, typename T, int NDIMS>
250struct NHWCToNCHW {
251  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
252                  typename TTypes<T, NDIMS>::Tensor out);
253};
254
255// Converts a tensor from:
256//   [batch, filters, <spatial>]
257// to:
258//   [batch, <spatial>, filters]
259template <typename Device, typename T, int NDIMS>
260struct NCHWToNHWC {
261  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
262                  typename TTypes<T, NDIMS>::Tensor out);
263};
264
265// Converts a tensor from:
266//   [dim0, dim1, dim2]
267// to:
268//   [dim0, dim2, dim1]
269template <typename Device, typename T, bool conjugate = false>
270struct SwapDimension1And2InTensor3 {
271  void operator()(const Device& d, const T* in,
272                  const gtl::ArraySlice<int64>& input_dims, T* out);
273};
274
275// Converts a tensor from:
276//   [dim0, dim1, dim2]
277// to:
278//   [dim2, dim1, dim0]
279template <typename Device, typename T, bool conjugate = false>
280struct SwapDimension0And2InTensor3 {
281  void operator()(const Device& d, const T* in,
282                  const gtl::ArraySlice<int64>& input_dims, T* out);
283};
284
285// Reverses the effect of TransformFilter above.
286template <typename Device, typename T, int NDIMS>
287struct ReverseTransformFilter {
288  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
289                  typename TTypes<T, NDIMS>::Tensor out);
290};
291
292}  // namespace functor
293
294template <class T>
295class ConvAlgorithmMap;
296
297template <>
298class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
299}  // namespace tensorflow
300
301#endif  // TENSORFLOW_KERNELS_CONV_2D_H_
302