/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
16b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#if GOOGLE_CUDA
17b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
18b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#define EIGEN_USE_GPU
19b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
20b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
22b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/framework/op_kernel.h"
23b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/framework/register_types.h"
24b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/framework/tensor.h"
25b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/framework/tensor_shape.h"
26b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/kernels/bucketize_op.h"
27b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/kernels/cuda_device_array.h"
28b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/platform/logging.h"
29b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/platform/types.h"
30b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#include "tensorflow/core/util/cuda_kernel_helper.h"
31b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
32b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Fengnamespace tensorflow {
33b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
// Short local name for the Eigen GPU device type used throughout this file.
using GPUDevice = Eigen::GpuDevice;
35b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
// Returns true when `value` lies strictly below `boundary`.
//
// Generic form, selected for the non-integer instantiations: the boundary is
// converted to T and the comparison is done in T.
template <typename T>
__device__ __forceinline__ bool BucketizeIsBelowBoundary(T value,
                                                         float boundary) {
  return value < static_cast<T>(boundary);
}

// Integer overloads. For an integer v and a real b, `v < b` holds exactly
// when `v < ceil(b)`. Converting b with a bare static_cast truncates toward
// zero instead, which rounds positive fractional boundaries down and would
// put e.g. the value 1 on the wrong side of a boundary of 1.5.
// NOTE(review): assumes the boundary is finite and representable in the
// integer type; the float-to-integer conversion is undefined otherwise.
__device__ __forceinline__ bool BucketizeIsBelowBoundary(int32 value,
                                                         float boundary) {
  return value < static_cast<int32>(ceilf(boundary));
}

__device__ __forceinline__ bool BucketizeIsBelowBoundary(int64 value,
                                                         float boundary) {
  return value < static_cast<int64>(ceilf(boundary));
}

// Writes, for every input element, the number of boundaries that are less
// than or equal to it (i.e. the index of the first boundary strictly greater
// than the element, found by binary search).
//
// Template parameters:
//   T            - element type of `in`.
//   useSharedMem - when true, each block first copies all boundaries into
//                  dynamic shared memory and searches that copy. The launch
//                  must then provide at least
//                  size_boundaries * sizeof(float) bytes of dynamic shared
//                  memory.
//
// Arguments:
//   size_in          - number of elements in `in` and `out`.
//   in               - input values.
//   size_boundaries  - number of boundaries.
//   boundaries_array - the boundaries; the binary search requires them to be
//                      sorted in ascending order.
//   out              - receives one bucket index per input element.
template <typename T, bool useSharedMem>
__global__ void BucketizeCustomKernel(
    const int32 size_in, const T* in, const int32 size_boundaries,
    CudaDeviceArrayStruct<float> boundaries_array, int32* out) {
  const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array);

  // Dynamic shared memory is declared as raw bytes and viewed as floats.
  extern __shared__ __align__(sizeof(float)) unsigned char shared_mem[];
  float* shared_mem_boundaries = reinterpret_cast<float*>(shared_mem);

  // useSharedMem is a compile-time constant, so every thread of the block
  // takes the same path and the barrier below is reached by all of them.
  if (useSharedMem) {
    // Linear thread index within the block and total threads per block.
    const int32 lidx = threadIdx.y * blockDim.x + threadIdx.x;
    const int32 blockSize = blockDim.x * blockDim.y;

    // The threads of the block cooperatively copy the boundaries.
    for (int32 i = lidx; i < size_boundaries; i += blockSize) {
      shared_mem_boundaries[i] = boundaries[i];
    }

    // All copies must be visible before any thread starts searching.
    __syncthreads();

    boundaries = shared_mem_boundaries;
  }

  CUDA_1D_KERNEL_LOOP(i, size_in) {
    const T value = in[i];
    // Binary search over [bucket, bucket + count) for the first boundary
    // that is strictly greater than `value`.
    int32 bucket = 0;
    int32 count = size_boundaries;
    while (count > 0) {
      const int32 step = count / 2;
      const int32 mid = bucket + step;
      if (BucketizeIsBelowBoundary(value, boundaries[mid])) {
        // boundaries[mid] is above the value: keep the lower half.
        count = step;
      } else {
        // boundaries[mid] <= value: continue right after mid.
        bucket = mid + 1;
        count -= step + 1;
      }
    }
    out[i] = bucket;
  }
}
76b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
namespace functor {

// GPU implementation of the bucketize functor: stages the boundaries in a
// device array and launches BucketizeCustomKernel on the device's stream.
template <typename T>
struct BucketizeFunctor<GPUDevice, T> {
  // Computes one bucket index per element of `input` and stores it in
  // `output`.
  //
  // PRECONDITION: boundaries_vector must be sorted.
  //
  // Returns a non-OK status if staging the boundaries fails or if the kernel
  // launch is rejected; Status::OK() otherwise. The kernel itself runs
  // asynchronously on the device stream.
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& input,
                        const std::vector<float>& boundaries_vector,
                        typename TTypes<int32, 1>::Tensor& output) {
    // With no input elements the kernel would write nothing. Returning early
    // also avoids requesting a launch configuration for zero work items and
    // launching an empty grid, which CUDA rejects as an invalid
    // configuration.
    if (input.size() == 0) {
      return Status::OK();
    }

    const GPUDevice& d = context->eigen_device<GPUDevice>();

    const int32 num_boundaries = static_cast<int32>(boundaries_vector.size());

    // Stage the boundaries so the kernel can read them.
    CudaDeviceArrayOnHost<float> boundaries_array(context, num_boundaries);
    TF_RETURN_IF_ERROR(boundaries_array.Init());
    for (int32 i = 0; i < num_boundaries; ++i) {
      boundaries_array.Set(i, boundaries_vector[i]);
    }
    TF_RETURN_IF_ERROR(boundaries_array.Finalize());

    CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d);

    // The byte count is computed in size_t so it cannot wrap around and so
    // the comparisons below are between values of the same signedness.
    const size_t shared_mem_bytes = sizeof(float) * boundaries_vector.size();
    const size_t kMaxSharedMemBytes = 16384;
    const bool use_shared_mem =
        shared_mem_bytes < static_cast<size_t>(d.sharedMemPerBlock()) &&
        shared_mem_bytes < kMaxSharedMemBytes;

    if (use_shared_mem) {
      // The boundaries fit: have each block cache them in shared memory.
      BucketizeCustomKernel<T, true>
          <<<config.block_count, config.thread_per_block, shared_mem_bytes,
             d.stream()>>>(input.size(), input.data(), num_boundaries,
                           boundaries_array.data(), output.data());
    } else {
      // Too many boundaries: search the staged device array directly.
      BucketizeCustomKernel<T, false>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              input.size(), input.data(), num_boundaries,
              boundaries_array.data(), output.data());
    }

    // A kernel launch has no return value; a rejected launch is only visible
    // through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess) {
      return errors::Internal("BucketizeCustomKernel launch failed: ",
                              cudaGetErrorString(launch_status));
    }
    return Status::OK();
  }
};
}  // namespace functor
115b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
// Explicit instantiations of the GPU bucketize functor, one per input element
// type compiled into this translation unit.
template struct functor::BucketizeFunctor<GPUDevice, int32>;
template struct functor::BucketizeFunctor<GPUDevice, int64>;
template struct functor::BucketizeFunctor<GPUDevice, float>;
template struct functor::BucketizeFunctor<GPUDevice, double>;
124b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
125b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng}  // namespace tensorflow
126b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
127b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng#endif  // GOOGLE_CUDA
128