/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {
namespace {
template <typename T, bool conjugate>
struct maybe_conj {
  __device__ static __inline__ T run(T x) {
    if (conjugate) {
      return Eigen::numext::conj(x);
    } else {
      return x;
    }
  }
};

// Partial specializations for Cuda types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
  __device__ static __inline__ float2 run(float2 c) {
    if (conjugate) {
      float2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

template <bool conjugate>
struct maybe_conj<double2, conjugate> {
  __device__ static __inline__ double2 run(double2 c) {
    if (conjugate) {
      double2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

}  // namespace

// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
    data[0] = a0;
    for (int i = 1; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = array[i];
    }
  }
  T data[IndexCount];
};

// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount, 1> {
  typedef Array<int, IndexCount, 1> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
      : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
  EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
      : Base(array) {}
};

// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount, 0> {
  typedef Array<int, IndexCount, 0> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// A helper function that converts a tensor index into a flat array index.
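// For example, with dims = {5, 4, 3} and index = {1, 2, 0}, the flat index is
// (1 * 4 + 2) * 3 + 0 = 18 (row-major order, innermost dimension last).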
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
  int flat_index = index[0];
  for (int i = 1; i < IndexCount; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// A helper function that converts a flat array index into a tensor index.
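// This is the inverse of TensorIndexToFlat: with dims = {5, 4, 3}, a flat
// index of 18 maps back to the tensor index {1, 2, 0}.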
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    int index, const Dimension<IndexCount>& dims) {
  Index<IndexCount> tensor_index;
  for (int i = IndexCount - 1; i >= 0; i--) {
    int new_index = index / dims[i];
    tensor_index[i] = index - dims[i] * new_index;
    index = new_index;
  }
  return tensor_index;
}

// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
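// That is, output[i][j][k] = input[k][j][i].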
template <typename T, bool conjugate = false>
__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
                                                  Dimension<3> input_dims,
                                                  T* output) {
  Dimension<3> output_dims;
  output_dims[0] = input_dims[2];
  output_dims[1] = input_dims[1];
  output_dims[2] = input_dims[0];

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;

    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[2];
    input_tensor_index[1] = output_tensor_index[1];
    input_tensor_index[2] = output_tensor_index[0];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}

// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
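// That is, output[i][j][k] = input[i][k][j].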
template <typename T, bool conjugate = false>
__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
                                                  Dimension<3> input_dims,
                                                  T* output) {
  Dimension<3> output_dims;
  output_dims[0] = input_dims[0];
  output_dims[1] = input_dims[2];
  output_dims[2] = input_dims[1];

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];
    input_tensor_index[1] = output_tensor_index[2];
    input_tensor_index[2] = output_tensor_index[1];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}

// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// Each thread block operates on a single tile, a rectangle of dimensions
// TileSizeI x TileSizeJ.
//
// In general, for best performance, you should probably set TileSizeI,
// TileSizeJ equal to the number of threads in a warp (32 on NVIDIA GPUs).
// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get
// the best performance on K40 GPUs.
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
          bool conjugate = false>
__global__ void SwapDimension1And2InTensor3UsingTiles(
    const T* __restrict__ input, Dimension<3> input_dims,
    T* __restrict__ output) {
  eigen_assert(blockDim.x == NumThreads);
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
  constexpr int WriteRowPerPass = NumThreads / TileSizeI;
  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts.
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];

  int x = threadIdx.x;

  Dimension<3> output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      (input_dims[1] + TileSizeI - 1) / TileSizeI,
      (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
  };
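  // For example, a 2 x 70 x 40 input covered by 32 x 32 tiles yields a
  // 2 x 3 x 2 grid of tiles (the last tile in each dimension may be partial).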

  Index<3> input_tile_index =
      FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);

  Index<3> input_tile_origin = {
      input_tile_index[0],
      input_tile_index[1] * TileSizeI,
      input_tile_index[2] * TileSizeJ,
  };

  int input_origin_flat_index =
      TensorIndexToFlat(input_tile_origin, input_dims);

  bool full_tile = true;
  int tile_width = TileSizeJ;

  // Only the last row or column may not have the full size.
  if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    full_tile &= false;
  }

  int tile_height = TileSizeI;

  if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    full_tile &= false;
  }

  // Calculate effective thread number. This ensures that we use the largest
  // number of threads available to form a regular thread block with no
  // trailing incomplete lines.
  constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;

  if (x < in_effective_thread_num) {
    // Orient the logical thread block with respect to the input array.
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the input array.
    int ti = x / TileSizeJ;
    int tj = x % TileSizeJ;
    int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    int input_increment = ReadRowPerPass * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
        shared_memory_tile[i_loc][tj] =
            maybe_conj<T, conjugate>::run(input[input_index]);
        input_index += input_increment;
      }
    } else {
      if (tj < tile_width) {
        for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
          shared_memory_tile[i_loc][tj] =
              maybe_conj<T, conjugate>::run(input[input_index]);
          input_index += input_increment;
        }
      }
    }
  }

  __syncthreads();

  Index<3> output_tile_index = {
      input_tile_index[0],
      input_tile_index[2],
      input_tile_index[1],
  };

  Index<3> output_tile_origin = {
      output_tile_index[0],
      output_tile_index[1] * TileSizeJ,
      output_tile_index[2] * TileSizeI,
  };

  int output_origin_flat_index =
      TensorIndexToFlat(output_tile_origin, output_dims);

  constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;

  if (x < out_effective_thread_num) {
    // Re-orient the logical thread block with respect to the output array.
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the output array.
    int ti = x / TileSizeI;
    int tj = x % TileSizeI;
    int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    int output_increment = WriteRowPerPass * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
        output[output_index] = shared_memory_tile[tj][i_loc];
        output_index += output_increment;
      }
    } else {
      if (tj < tile_height) {
        for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
          output[output_index] = shared_memory_tile[tj][i_loc];
          output_index += output_increment;
        }
      }
    }
  }
}

// A Cuda custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
                                         Dimension<NDIMS> input_dims, T* output,
                                         Dimension<NDIMS> output_dims,
                                         Dimension<NDIMS - 2> padding_left) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    bool ok = true;
    for (int i = 1; i < NDIMS - 1; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }
    input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = T(0);
    }
  }
}

template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
                                         Dimension<NDIMS> input_dims, T* output,
                                         Dimension<NDIMS> output_dims,
                                         Dimension<NDIMS - 2> padding_left) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    input_tensor_index[1] = output_tensor_index[1];  // channels
    bool ok = true;
    for (int i = 2; i < NDIMS; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = T(0);
    }
  }
}

// A GPU helper function that converts TensorFlow filter format to Cudnn filter
// format.
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  typename TTypes<T, NDIMS, int>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // spatial dimensions
    for (int i = 1; i < NDIMS - 2; i++) {
      combined_dims[0] *= in.dimension(i);
    }
    combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    SwapDimension0And2InTensor3Simple<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, in.data(), combined_dims, out.data());
  }
};

// Converts Cudnn filter format back to TensorFlow filter format.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // output filters
    combined_dims[1] = in.dimension(1);  // input filters
    combined_dims[2] = in.dimension(2);  // spatial dimensions
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    SwapDimension0And2InTensor3Simple<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, in.data(), combined_dims, out.data());
  }
};

// A GPU helper function that converts input tensor to a larger output tensor,
// given proper padding values. The padded value is zero.
template <typename T, int NDIMS>
struct PadInput<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  const std::array<int, NDIMS - 2>& padding_left,
                  const std::array<int, NDIMS - 2>& padding_right,
                  typename TTypes<T, NDIMS, int>::Tensor out,
                  TensorFormat format) {
    CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    Dimension<NDIMS> input_dims;
    for (int i = 0; i < NDIMS; ++i) {
      input_dims[i] = in.dimension(i);
    }
    Dimension<NDIMS> output_dims;
    for (int i = 0; i < NDIMS; ++i) {
      output_dims[i] = out.dimension(i);
    }

    const Dimension<NDIMS - 2> padding_left_dim(padding_left);

    if (format == FORMAT_NHWC) {
      PadInputCustomKernelNHWC<T, NDIMS>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              config.virtual_thread_count, in.data(), input_dims, out.data(),
              output_dims, padding_left_dim);
    } else if (format == FORMAT_NCHW) {
      PadInputCustomKernelNCHW<T, NDIMS>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              config.virtual_thread_count, in.data(), input_dims, out.data(),
              output_dims, padding_left_dim);
    } else {
      LOG(FATAL) << "Invalid data format: " << format;
    }
  }
};

// We want std::equal_to and std::greater, but they're not constexpr until
// C++14.
struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// For each data type, the tile size possibility frontier denotes the tile size
// combinations that consume the most computational resources constrained by
// - number of threads per SM limit,
// - limit on size of the short dimension (<=15) due to the definition of
//   narrow matrix,
// - shared memory limit and
// - some experimentally determined, type-specific constraint on the product of
//   two side lengths to increase grid-level parallelism.
//
// A tile size combination lies on the frontier if and only if one or more of
// the constraints mentioned above is hit. Tile size combinations lying outside
// this frontier are either not possible, or are slower than the alternatives.
//
// It is useful to consider, for each data type, two subsets of the
// corresponding frontier:
// - long side frontier: the union of the biggest tile size combinations for
//   each legal long side len.
// - non long side frontier: the frontier set minus the long side frontier.
//
// TileSizePossibilityFrontierCheck defines the frontier using only the long
// side frontier tile size combinations (since one can easily extrapolate
// the entire frontier from this subset). It serves as a utility function
// to help us determine where a tile size combination of interest lies with
// respect to the frontier.
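// For example, for 4-byte element types the long side frontier consists of the
// tile size combinations (32, 15), (64, 15), (128, 15), (256, 8), (512, 4) and
// (1024, 2).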
template <typename Op>
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
                                                int TileShortSide,
                                                int size_of_t, Op op) {
  // clang-format off

  return (size_of_t == 16 && ((TileLongSide == 32   && op(TileShortSide, 4))  ||
                             (TileLongSide == 64   && op(TileShortSide, 4))  ||
                             (TileLongSide == 128  && op(TileShortSide, 4))  ||
                             (TileLongSide == 256  && op(TileShortSide, 2)))) ||
          (size_of_t == 8 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 8))  ||
                             (TileLongSide == 256  && op(TileShortSide, 4))  ||
                             (TileLongSide == 512  && op(TileShortSide, 2)))) ||
          (size_of_t == 4 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 2 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 1 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2))));

  // clang-format on
}

constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
                                          int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, EqualTo());
}
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
                                       int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, GreaterThan());
}
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
                                             int TileShortSide, int size_of_t) {
  // For a tile size combination (longside, shortside), lying on the frontier
  // implies that (longside, shortside) is on or within the frontier but
  // (longside*2, shortside) or (longside, shortside+1) is not. With the above
  // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure that
  // it is not on the long side frontier.
  return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
         (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
          TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
                                  size_of_t)) &&
         !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
}

// Helper function to launch a batch narrow matrix transpose kernel.
template <typename T, int TileLongSide, int TileShortSide>
void LaunchBatchNarrowMatrixTransposeKernel(
    const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    const T* input, const Dimension<3>& input_dims, T* output) {
  constexpr int NumThreads = TileLongSide;
  if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                          TileShortSide>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
                                                           output);
  } else {
    SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                          TileLongSide>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
                                                           output);
  }
}

// Recursive template function to search, in a trial-and-error manner, for the
// minimum tile size configuration satisfying the requested tile side lengths.
// An important invariant of this search procedure is that for an unsatisfied
// request, we always try doubling the long side len first, and only after
// the request is satisfied for the long side len do we begin incrementing
// the short side len.
//
// We have three specializations of this search function depending on where the
// current tile size combination lies with respect to the frontier.
// - It lies within the frontier. If request is not satisfied, for the next tile
// size combination, we first try doubling the long side len and if that does
// not work, we then increment the short side len.
// - It lies on the non long side frontier. If the request is not satisfied, we
// can only increment the short side len.
// - It lies on the long side frontier. We launch the kernel without checking if
// the request is satisfied or not.
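// For example, for 4-byte elements a request of (tile_size_i, tile_size_j) =
// (300, 3) starting from the initial (32, 2) instantiation doubles the long
// side through 64, 128, 256 and 512 and then increments the short side to 3,
// at which point the kernel is launched with a 512 x 3 tile.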
template <typename T, int TileLongSide, int TileShortSide,
          typename dummy = void>
struct BatchNarrowMatrixTransposeDispatcher {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; we then
    // determine whether it is the long side or the short side that falls short
    // of the request and increase that parameter accordingly.
    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > TileLongSide;

    if (long_side_request_not_satisfied) {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide * 2, TileShortSide>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    } else {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    }
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnNonLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; since
    // we are on the non long side frontier, we increment the short dimension
    // and try again.
    BatchNarrowMatrixTransposeDispatcher<
        T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                  total_tiles_count, input,
                                                  input_dims, output);
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");

    LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};

// This function tries to recover, in a brute force way, the frontier defined in
// TileSizePossibilityFrontierCheck as a vector of tile size combinations lying
// on the long side frontier. This vector is sufficient to determine the entire
// frontier.
//
// Note that if one changes the frontier definition in
// TileSizePossibilityFrontierCheck and forgets to set the largest short
// side len of the largest legal long side len to 2, this function will fail
// and crash the program.
template <int SizeOfT>
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
  static_assert(
      SizeOfT <= 16,
      "Currently, only data types of sizes 16 bytes or less are supported.");
  static_assert((SizeOfT & (SizeOfT - 1)) == 0,
                "Data types must have sizes that are powers of 2.");

  // Expensive work to populate sizes, lazily run in a thread-safe
  // manner the first time GetTileSizesFrontier<N> is called.
  static auto* frontier = [] {
    auto* frontier = new std::vector<std::pair<int, int>>();
    const int kMaxLongSideLen = 1024;
    const int kMaxShortSideLen = 15;
    for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
      for (int short_side = 2; short_side <= kMaxShortSideLen;
           short_side += 1) {
        if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
          // The current combination lies on the frontier, thus we
          // add it to the frontier definition.
          frontier->push_back(std::make_pair(long_side, short_side));

          // The long side length is the largest one allowed iff its
          // corresponding short side length is 2.
          if (short_side == 2) return frontier;

          // We have exhausted all the possibilities in the frontier
          // with the given long side length.
          break;
        }
      }
    }
    LOG(FATAL)
        << "The corresponding short side length of the largest long side "
           "length has to be 2.";
  }();
  return *frontier;
}

// Helper structs to help determine which data type to use given the size of
// the matrix data type. A transpose of elements of size N will use a kernel
// which operates on an array of TransposeElemType<N>::type.
template <int ElemBytes>
struct TransposeElemType;
template <>
struct TransposeElemType<1> {
  using type = uint8;
};
template <>
struct TransposeElemType<2> {
  using type = uint16;
};
template <>
struct TransposeElemType<4> {
  using type = uint32;
};
template <>
struct TransposeElemType<8> {
  using type = uint64;
};
template <>
struct TransposeElemType<16> {
  using type = float4;
};

// A helper function to make RunSwapDimension1And2InTensor3 concise. This
// helper function looks at the data type and input matrix sizes and decides
// the thread numbers and tile sizes to use.
template <typename T, bool conjugate = false>
void SwapDimension1And2InTensor3WithNarrowMatrices(
    const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    T* output, const int kMinDimensionToUseTiles) {
  // Get available tile sizes here for the data type requested:
  const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();

  int tile_long_side_len = 0;
  int tile_short_side_len = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int data_long_side = std::max(input_dims[1], input_dims[2]);

  for (auto tile_size_pair : tile_spec) {
    int proposed_tile_long_side_len = tile_size_pair.first;

    // Number of threads that will not be doing anything useful when reading
    // the matrix because the thread block size is bigger than the data block
    // size.
    int num_wasted_threads =
        data_long_side - MathUtil::FloorOfRatio<int>(
                             data_long_side, proposed_tile_long_side_len) *
                             proposed_tile_long_side_len;

    int num_full_tiles = MathUtil::FloorOfRatio<int>(
        data_long_side, proposed_tile_long_side_len);

    float cost = 0;

    // However, if we can execute two or more full tiles, then we gladly
    // accept any number of wasted threads and ignore its cost.
    if (num_full_tiles <= 1) cost = num_wasted_threads;

    // Using less than or equal to here because given the same cost, we
    // would like to launch as many threads as possible.
    if (cost <= lowest_cost) {
      tile_long_side_len = proposed_tile_long_side_len;
      tile_short_side_len = tile_size_pair.second;
      lowest_cost = cost;
    }
  }
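  // For example, with 4-byte elements and a data long side of 1000, 256 is the
  // largest frontier long side length that still yields more than one full
  // tile (cost 0), so the (256, 8) combination is selected.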

  // Request tile sizes such that the longer side of threadblock aligns with
  // the longer side of input data block to maximize read throughput.
  // The ideal tile shape is one where the length of the shorter side of the
  // tile is equal to the length of the shorter side of the input matrix.
  int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
                                  ? tile_long_side_len
                                  : input_dims[1];
  int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
                                  ? input_dims[2]
                                  : tile_long_side_len;

  // Truncate the shorter size requested according to the manual limit set in
  // tile_spec to make sure that we do not launch configurations violating
  // hardware limits.
  requested_tile_size_i =
      requested_tile_size_i == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_i, tile_short_side_len);
  requested_tile_size_j =
      requested_tile_size_j == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_j, tile_short_side_len);

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
      MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
  };

  int total_tiles_count =
      input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];

  using ElemType = typename TransposeElemType<sizeof(T)>::type;
  static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
  BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2>::DoIt(
      d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
      reinterpret_cast<const ElemType*>(input), input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
// 3D tensor. It looks at the shape of the incoming data, and decides the best
// strategy to launch.
template <typename T, bool conjugate = false>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both dimensions are large, use the square-tiled kernel for the actual
  // swapping. If only one dimension is large enough, use the narrow-matrix
  // kernel with rectangular tiles. Otherwise, the simple kernel relying on the
  // ldg cache is more efficient.
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;

  bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
                      input_dims[2] >= kMinDimensionToUseTiles;
  bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
                       input_dims[2] >= kMinDimensionToUseRectTiles;
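  // For example, a {N, 8, 1024} tensor takes the narrow-matrix path, while a
  // {N, 8, 32} tensor falls through to the simple element-wise kernel.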
  if (large_matrix) {
    // We get best performance when kTileSize is the number of threads in a warp
    // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
    // threads.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dimension<3> input_dims_in_tiles = {
        input_dims[0],
        MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
        MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    };

    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize,
                                          conjugate>
        <<<total_tiles_count, kNumThreads, 0, d.stream()>>>(input, input_dims,
                                                            output);

  } else if (narrow_matrix) {
    SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
        d, input, input_dims, output, kMinDimensionToUseTiles);
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
    SwapDimension1And2InTensor3Simple<T, conjugate>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, input, input_dims, output);
  }
}

// A GPU helper functor that does general dimension 1 and 2 switch for 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
  }
};

// A GPU helper functor that does general dimension 0 and 2 switch for 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
    SwapDimension0And2InTensor3Simple<T, conjugate>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, in, input_dims, out);
  }
};

// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T, int NDIMS>
struct NHWCToNCHW<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    for (int i = 2; i < NDIMS - 1; ++i) {
      combined_dims[1] *= in.dimension(i);
    }
    combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// format.
template <typename T, int NDIMS>
struct NCHWToNHWC<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // C (channel)
    combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

}  // namespace functor

template struct functor::ShuffleAndReverse<GPUDevice, float, 4, int>;
template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4, int>;

template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
                                           Eigen::DenseIndex>;
template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4,
                                           Eigen::DenseIndex>;

template struct functor::TransformDepth<GPUDevice, float, int>;
template struct functor::TransformDepth<GPUDevice, Eigen::half, int>;

template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint8>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint16>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint32>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint64>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, float4>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, float2,
                                                     /*conjugate=*/true>;
template struct functor::SwapDimension1And2InTensor3<GPUDevice, double2,
                                                     /*conjugate=*/true>;

template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint8>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint16>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint32>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint64>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, float4>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, float2,
                                                     /*conjugate=*/true>;
template struct functor::SwapDimension0And2InTensor3<GPUDevice, double2,
                                                     /*conjugate=*/true>;

// For 2d ops.
template struct functor::TransformFilter<GPUDevice, float, int, 4>;
template struct functor::TransformFilter<GPUDevice, Eigen::half, int, 4>;

template struct functor::ReverseTransformFilter<GPUDevice, float, 4>;
template struct functor::ReverseTransformFilter<GPUDevice, Eigen::half, 4>;

template struct functor::NHWCToNCHW<GPUDevice, double, 4>;
template struct functor::NHWCToNCHW<GPUDevice, float, 4>;
template struct functor::NHWCToNCHW<GPUDevice, Eigen::half, 4>;

template struct functor::NCHWToNHWC<GPUDevice, double, 4>;
template struct functor::NCHWToNHWC<GPUDevice, float, 4>;
template struct functor::NCHWToNHWC<GPUDevice, Eigen::half, 4>;

template struct functor::PadInput<GPUDevice, int, int, 4>;
template struct functor::PadInput<GPUDevice, float, int, 4>;
template struct functor::PadInput<GPUDevice, Eigen::half, int, 4>;

// For 3d ops.
template struct functor::TransformFilter<GPUDevice, float, int, 5>;
template struct functor::TransformFilter<GPUDevice, Eigen::half, int, 5>;

template struct functor::ReverseTransformFilter<GPUDevice, float, 5>;
template struct functor::ReverseTransformFilter<GPUDevice, Eigen::half, 5>;

template struct functor::NHWCToNCHW<GPUDevice, float, 5>;
template struct functor::NHWCToNCHW<GPUDevice, Eigen::half, 5>;

template struct functor::NCHWToNHWC<GPUDevice, float, 5>;
template struct functor::NCHWToNHWC<GPUDevice, Eigen::half, 5>;

template struct functor::PadInput<GPUDevice, float, int, 5>;
template struct functor::PadInput<GPUDevice, Eigen::half, int, 5>;

}  // namespace tensorflow

#endif  // GOOGLE_CUDA