/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Specialization of SpaceToBatchFunctor for a GPUDevice.

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// Shape and padding parameters for space-to-batch and batch-to-space conversion
// GPU kernel.
template <int NUM_BLOCK_DIMS>
struct S2BParameters {
  int32 space_tensor_batch;
  int32 batch_tensor_shape[NUM_BLOCK_DIMS + 2];
  int32 space_tensor_spatial_shape[NUM_BLOCK_DIMS];
  int32 pad_start[NUM_BLOCK_DIMS];
  int32 block_shape[NUM_BLOCK_DIMS];
};
// GPU kernel for space-to-batch (if B2S = false) and batch-to-space conversion
// (if B2S = true).
//
// To simplify the template implementation given the lack of constexpr if, both
// the input and output pointers are non-const.
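//
// The batch dimension of the batch tensor is laid out as the block offsets (in
// row-major order over the NUM_BLOCK_DIMS block dimensions) with the original
// space tensor batch index varying fastest.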
template <typename T, int NUM_BLOCK_DIMS, bool B2S>
__global__ void S2B(const int32 nthreads, T* space_tensor_ptr,
                    S2BParameters<NUM_BLOCK_DIMS> args, T* batch_tensor_ptr) {
  CUDA_1D_KERNEL_LOOP(batch_tensor_idx, nthreads) {
    int32 remaining_batch_tensor_idx = batch_tensor_idx;

    int32 batch_tensor_pos[NUM_BLOCK_DIMS + 2];

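    // Decompose the flat batch tensor index into per-dimension coordinates,
    // peeling off the innermost (fastest-varying) dimension first.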
    for (int dim = NUM_BLOCK_DIMS + 1; dim >= 1; --dim) {
      batch_tensor_pos[dim] =
          remaining_batch_tensor_idx % args.batch_tensor_shape[dim];
      remaining_batch_tensor_idx /= args.batch_tensor_shape[dim];
    }
    batch_tensor_pos[0] = remaining_batch_tensor_idx;

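    // The batch coordinate combines the spatial block offset and the original
    // space tensor batch index:
    //   batch_tensor_pos[0] = block_idx * space_tensor_batch + space_batch_pos.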
    int32 remaining_block_idx = batch_tensor_pos[0] / args.space_tensor_batch;
    int32 space_tensor_idx = batch_tensor_pos[NUM_BLOCK_DIMS + 1];
    int32 space_tensor_stride = args.batch_tensor_shape[NUM_BLOCK_DIMS + 1];
    const int32 space_tensor_batch_pos =
        batch_tensor_pos[0] % args.space_tensor_batch;
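    // Walk the block dimensions from innermost to outermost, accumulating the
    // flat space tensor index and consuming one block-offset digit per step.
    // Positions that fall outside the space tensor correspond to padding
    // (space-to-batch) or cropped elements (batch-to-space).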
    for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
      int32 offset = remaining_block_idx;
      if (block_dim > 0) {
        offset %= args.block_shape[block_dim];
      }
      int32 space_tensor_pos =
          batch_tensor_pos[block_dim + 1] * args.block_shape[block_dim] +
          offset - args.pad_start[block_dim];
      if (space_tensor_pos < 0 ||
          space_tensor_pos >= args.space_tensor_spatial_shape[block_dim]) {
        if (B2S == false) {
          // In the space-to-batch case, write zero padding.
          batch_tensor_ptr[batch_tensor_idx] = static_cast<T>(0);
        }
        break;
      }
      space_tensor_idx += space_tensor_stride * space_tensor_pos;
      space_tensor_stride *= args.space_tensor_spatial_shape[block_dim];
      if (block_dim == 0) {
        space_tensor_idx += space_tensor_stride * space_tensor_batch_pos;
        if (B2S == false) {
          batch_tensor_ptr[batch_tensor_idx] =
              ldg(space_tensor_ptr + space_tensor_idx);
        } else {
          space_tensor_ptr[space_tensor_idx] =
              ldg(batch_tensor_ptr + batch_tensor_idx);
        }
      }
      remaining_block_idx /= args.block_shape[block_dim];
    }
  }
}

namespace functor {
template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
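  // For space-to-batch (B2S == false) the space tensor is the read-only input
  // and the batch tensor is the output; the constness flips for batch-to-space.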
  Status operator()(
      const GPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64 block_shape[NUM_BLOCK_DIMS],
      const int64 paddings[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    // Kernel execution fails if the number of elements is zero.
    if (batch_tensor.size() == 0) {
      return Status::OK();
    }
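    // Copy shape and padding parameters into a POD struct that is passed to
    // the kernel by value, verifying that every value fits in int32.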
    S2BParameters<NUM_BLOCK_DIMS> args;
    args.space_tensor_batch = space_tensor.dimension(0);
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      if (block_shape[block_dim] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("block_shape value exceeds 2^31-1");
      }
      args.block_shape[block_dim] = block_shape[block_dim];
      if (space_tensor.dimension(block_dim + 1) >
          std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("space_tensor dimension exceeds 2^31-1");
      }
      args.space_tensor_spatial_shape[block_dim] =
          space_tensor.dimension(block_dim + 1);
      if (paddings[block_dim * 2] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("paddings/crops value exceeds 2^31-1");
      }
      args.pad_start[block_dim] = paddings[block_dim * 2];
    }
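    // The kernel assigns one logical thread per batch tensor element, so the
    // total element count must also fit in int32.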
    int64 total_count = 1;
    for (int dim = 0; dim < NUM_BLOCK_DIMS + 2; ++dim) {
      args.batch_tensor_shape[dim] = batch_tensor.dimension(dim);
      total_count *= args.batch_tensor_shape[dim];
    }
    if (total_count > std::numeric_limits<int32>::max()) {
      return errors::InvalidArgument(
          "number of batch_tensor elements exceeds 2^31-1");
    }
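    // Launch a 1-D grid covering all batch tensor elements. The const_casts
    // are needed because the kernel takes non-const pointers for both tensors
    // (see the comment on S2B above).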
    CudaLaunchConfig config =
        GetCudaLaunchConfig(static_cast<int32>(total_count), d);
    S2B<T, NUM_BLOCK_DIMS, B2S>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, const_cast<T*>(space_tensor.data()),
            args, const_cast<T*>(batch_tensor.data()));
    return Status::OK();
  }
};

// Instantiate.
#define INSTANTIATE(NUM_BLOCK_DIMS, T)                                      \
  template struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, false>; \
  template struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, true>;  \
  /**/

#define INSTANTIATE_FOR_T(T) \
  TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T)

TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_FOR_T)

#undef INSTANTIATE_FOR_T
#undef INSTANTIATE

}  // end namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA