/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

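// Concatenates the inputs in `input_ptr_data` along the column dimension into
// `output` (total_rows x total_cols).  Every input is assumed to be
// total_rows x split_size, so the source pointer and column offset for an
// output column can be computed arithmetically.  Both loops are grid-stride,
// so any 2D launch configuration covers the whole output.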
template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
    int total_rows, int total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = gidx / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = gidx % split_size;
#pragma unroll
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y) {
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * split_size + col_offset];
    }
  }
}

}  // end namespace

// cannot be in anonymous namespace due to extern shared memory
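// Concatenates inputs of possibly different column widths.  `output_scan`
// holds the running sum of the input widths (i.e. the starting output column
// of each input, followed by the total column count), so a search over it
// maps an output column to its source input.  When useSmem is true the scan
// array is first staged in shared memory to avoid repeated global reads.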
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data,
    CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
    IntType total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  IntType num_inputs = input_ptr_data.size;

  // verbose declaration needed due to template
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

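  // optionally stage the column scan in shared memory; all threads in the
  // block cooperate on the copy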
  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // do an initial binary search and then scan linearly from there; this
  // works well both when there are many small segments and when the
  // segments are much longer
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_inputs, gidx) - 1;

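  // each thread then scans forward from that segment as the grid-stride loop
  // advances to larger output columns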
  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = input_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * segment_width + local_col];
  }
}

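// Concatenates the flattened 2D inputs into `output` using one Eigen slice
// assignment per input.  The int32 specialization goes through To32Bit to
// avoid 64-bit index arithmetic when the output fits in 32-bit indices.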
template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output) {
  Eigen::array<IntType, 2> offset{0, 0};
  for (int i = 0; i < inputs_flat.size(); ++i) {
    Eigen::array<IntType, 2> size;
    size[0] = inputs_flat[i]->dimension(0);
    size[1] = inputs_flat[i]->dimension(1);
    if (std::is_same<IntType, int32>::value) {
      To32Bit(*output).slice(offset, size).device(gpu_device) =
          To32Bit(*inputs_flat[i]);
    } else {
      output->slice(offset, size).device(gpu_device) = *inputs_flat[i];
    }

    offset[1] += size[1];
  }
}

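// Dispatches one of the kernels above over a 2D launch configuration that
// covers `output`.  When `fixed_size` is true all inputs share the width
// `split_size` and the cheaper fixed kernel is used; otherwise the variable
// kernel is used, caching `output_scan` in shared memory whenever it fits
// under both the hardware limit and an empirical performance cap.
//
// A minimal sketch of how a caller might build the device-side arrays,
// assuming the CudaDeviceArrayOnHost helper from
// "tensorflow/core/kernels/cuda_device_array.h" and an OpKernelContext* c,
// with `inputs_flat` and `output` as in ConcatGPUSlice above (none of these
// are part of this file):
//
//   CudaDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
//   OP_REQUIRES_OK(c, input_ptrs.Init());
//   for (int i = 0; i < inputs_flat.size(); ++i)
//     input_ptrs.Set(i, inputs_flat[i]->data());
//   OP_REQUIRES_OK(c, input_ptrs.Finalize());
//
//   CudaDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
//   OP_REQUIRES_OK(c, output_scan.Init());
//   IntType scan = 0;
//   output_scan.Set(0, scan);
//   for (int i = 0; i < inputs_flat.size(); ++i) {
//     scan += inputs_flat[i]->dimension(1);
//     output_scan.Set(i + 1, scan);
//   }
//   OP_REQUIRES_OK(c, output_scan.Finalize());
//
//   ConcatGPUImpl<T, IntType>(c->eigen_gpu_device(), input_ptrs.data(),
//                             output_scan.data(), /*fixed_size=*/false,
//                             /*split_size=*/0, output);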
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
                   const CudaDeviceArrayStruct<IntType>& output_scan,
                   bool fixed_size, int split_size,
                   typename TTypes<T, 2>::Matrix* output) {
  auto config = GetCuda2DLaunchConfig(output->dimension(1),
                                      output->dimension(0), gpu_device);

  if (fixed_size) {
    concat_fixed_kernel<T, IntType>
        <<<config.block_count, config.thread_per_block, 0,
           gpu_device.stream()>>>(input_ptrs, split_size, output->dimension(0),
                                  output->dimension(1), output->data());
  } else {
    IntType smem_max = gpu_device.sharedMemPerBlock();
    IntType smem_usage = output_scan.size * sizeof(IntType);
    // the performance crossover sits below the maximum available shared memory
    // on most processors, possibly due to decreasing occupancy; 4096 inputs
    // is a lot, so most code will take the smem path
    const int32 kMaxSmemBytesPerformance = 16384;
    if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
      concat_variable_kernel<T, IntType, true>
          <<<config.block_count, config.thread_per_block, smem_usage,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
    else
      concat_variable_kernel<T, IntType, false>
          <<<config.block_count, config.thread_per_block, 0,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
  }
}

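// Explicit instantiations of the templates above for the element types and
// index widths used by the concat op.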
#define REGISTER_GPUCONCAT32(T)                                               \
  template void ConcatGPUSlice<T, int32>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPUCONCAT64(T)                                               \
  template void ConcatGPUSlice<T, int64>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU32(T)                                               \
  template void ConcatGPUImpl<T, int32>(                                \
      const Eigen::GpuDevice& d,                                        \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                \
      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU64(T)                                               \
  template void ConcatGPUImpl<T, int64>(                                \
      const Eigen::GpuDevice& d,                                        \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                \
      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int64(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int64(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int64(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int64(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);

#undef REGISTER_GPUCONCAT32
#undef REGISTER_GPUCONCAT64
#undef REGISTER_GPU32
#undef REGISTER_GPU64

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA