/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// TODO(yangzihao): Remove the dependency on conv_2d.h once we move all
// GPU util functions and transpose kernels into separate files.
#include "tensorflow/core/kernels/conv_2d.h"

typedef Eigen::GpuDevice GPUDevice;

namespace tensorflow {
namespace internal {

// General-purpose transpose kernel: for each linear output index, decode the
// output coordinates against `out_strides`, re-weight each coordinate by the
// stride of the corresponding (permuted) input dimension, and gather from the
// resulting input index. `buf` packs, in order, the input strides, the output
// strides, and the permutation, each of length `ndims`.
template <typename T, bool conjugate>
__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
                                const int32 ndims, T* dst) {
  const int32* in_strides = buf;
  const int32* out_strides = buf + ndims;
  const int32* perm = buf + ndims * 2;
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    int32 i_idx = 0;
    int32 t = o_idx;
    for (int32 i = 0; i < ndims; ++i) {
      const int32 ratio = t / out_strides[i];
      t -= ratio * out_strides[i];
      i_idx += ratio * in_strides[perm[i]];
    }
    if (conjugate) {
      dst[o_idx] = Eigen::numext::conj(ldg(src + i_idx));
    } else {
      dst[o_idx] = ldg(src + i_idx);
    }
  }
}

template <typename T, bool conjugate>
void TransposeSimple(const GPUDevice& d, const Tensor& in,
                     const gtl::ArraySlice<int32> perm, Tensor* out) {
  // Ensures we can use 32-bit indices.
  const int64 nelem = in.NumElements();
  CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
  // Pack the strides and the permutation into one buffer.
  const int32 ndims = in.dims();
  gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
  gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
  gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
  // Dimension permutation.
  for (int i = 0; i < ndims; ++i) {
    host_buf[i] = in_strides[i];
    host_buf[ndims + i] = out_strides[i];
    host_buf[ndims * 2 + i] = perm[i];
  }
  // Copies the input strides, output strides and permutation to the device.
  // host_buf holds int32 values, so the byte count must use sizeof(int32);
  // sizing it with sizeof(int64) would read past the end of host_buf.
  auto num_bytes = sizeof(int32) * host_buf.size();
  auto dev_buf = d.allocate(num_bytes);
  // NOTE: host_buf is not allocated by CudaHostAllocator, so this is
  // effectively a synchronous copy.
  d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
  // Launch the kernel that computes q[o_idx] = p[i_idx] for every output
  // index.
  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
  TransposeKernel<T, conjugate>
      <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
          cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
          ndims, q);
  // Safe to deallocate immediately after the kernel launch.
  d.deallocate(dev_buf);
}
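
// Illustrative host-side sketch (not used by the kernels above; the helper
// name is hypothetical) of the output-to-input index mapping TransposeKernel
// performs per thread. For example, transposing a [2, 3] tensor (in_strides
// {3, 1}) to [3, 2] (out_strides {2, 1}) with perm {1, 0} maps output index
// 3, i.e. out(1, 1), to 1 * in_strides[1] + 1 * in_strides[0] = 4, i.e.
// in(1, 1).
inline int32 TransposeIndexSketch(int32 o_idx, const int32* in_strides,
                                  const int32* out_strides, const int32* perm,
                                  int32 ndims) {
  int32 i_idx = 0;
  int32 t = o_idx;
  for (int32 i = 0; i < ndims; ++i) {
    const int32 ratio = t / out_strides[i];  // i-th output coordinate
    t -= ratio * out_strides[i];             // remainder for lower dims
    i_idx += ratio * in_strides[perm[i]];    // re-weight by the input stride
  }
  return i_idx;
}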

// TransposeUsingTile tries to reduce the rank of the input tensor to 2 or 3
// and then calls specialized kernels that swap either dimensions 1 and 2 or
// dimensions 0 and 2. Returns true if the transpose was performed, false
// otherwise.
template <typename T, bool conjugate = false>
struct TransposeUsingTile {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    // First try to reduce the dimensions of the input tensor.
    TransposePermsVec new_perm;
    TransposeDimsVec new_dims;
    ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);

    // Only use the specialized GPU kernels when the reduced rank is 2 or 3.
    int dims = new_dims.size();
    if (dims < 2 || dims > 3) return false;
    auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
    auto out_data =
        reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
    switch (dims) {
      case 2:
        if (new_perm[0] == 1 && new_perm[1] == 0) {
          // Prepend a dummy leading dimension of size 1 so the 2-D swap can
          // reuse the 3-D kernel.
          new_dims.insert(new_dims.begin(), 1);
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        }
        break;
      case 3:
        if (new_perm == TransposePermsVec({0, 2, 1})) {
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else if (new_perm == TransposePermsVec({2, 1, 0})) {
          tensorflow::functor::SwapDimension0And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else {
          // Other 3-D permutations are not handled by the tile kernels.
          return false;
        }
        break;
      default:
        return false;
    }
    return false;
  }
};
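
// Illustrative sketch of the dimension reduction used above. The real
// implementation is ReduceTransposeDimensions in transpose_functor.h; this
// simplified standalone copy (with a hypothetical name) merges maximal runs
// of input dimensions that stay adjacent and in order under the permutation,
// and omits the size-1 dimension elision the real routine also performs.
// For example, shape [2, 3, 4, 5] with perm {0, 2, 3, 1} reduces to shape
// [2, 3, 20] with perm {0, 2, 1}, which SwapDimension1And2InTensor3 handles.
static void ReduceTransposeDimensionsSketch(const TensorShape& shape,
                                            gtl::ArraySlice<int32> perm,
                                            TransposePermsVec* new_perm,
                                            TransposeDimsVec* new_dims) {
  // Build groups in output order: a new group starts whenever the next
  // permuted dimension does not directly follow the previous one.
  TransposePermsVec lead, last;  // first/last input dim of each group
  TransposeDimsVec size;         // product of the group's dimension sizes
  for (size_t i = 0; i < perm.size(); ++i) {
    const int32 dim = perm[i];
    if (!lead.empty() && dim == last.back() + 1) {
      last.back() = dim;
      size.back() *= shape.dim_size(dim);
    } else {
      lead.push_back(dim);
      last.push_back(dim);
      size.push_back(shape.dim_size(dim));
    }
  }
  // The reduced permutation maps each group to its rank in input order; with
  // at most a handful of groups, the quadratic rank computation is cheap.
  new_perm->resize(lead.size());
  new_dims->resize(lead.size());
  for (size_t g = 0; g < lead.size(); ++g) {
    int32 rank = 0;
    for (size_t h = 0; h < lead.size(); ++h) {
      if (lead[h] < lead[g]) ++rank;
    }
    (*new_perm)[g] = rank;
    (*new_dims)[rank] = size[g];
  }
}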

// For complex types, a non-conjugating transpose only moves elements, so it
// can run on the raw bits: complex64 (8 bytes) is transposed as uint64 and
// complex128 (16 bytes) as float4. Conjugation needs access to the real and
// imaginary parts, so it uses float2/double2 instead.
template <bool conjugate>
struct TransposeUsingTile<complex64, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<uint64>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<float2, true>::run(d, in, perm, out);
    }
  }
};

template <bool conjugate>
struct TransposeUsingTile<complex128, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<float4>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<double2, true>::run(d, in, perm, out);
    }
  }
};

}  // namespace internal

// Transpose kernel specialized for GPUDevice. For ranks 2 through 8, first
// try the tiled kernels and fall back to the Eigen shuffle; for other ranks,
// use the general TransposeSimple kernel.
template <typename T, bool conjugate>
struct Transpose<GPUDevice, T, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    switch (in.dims()) {
#define HANDLE_DIM(N)                                                     \
  case N:                                                                 \
    if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,     \
                                                         out)) {          \
      internal::TransposeUsingEigen<GPUDevice, T, N>(d, in, perm,         \
                                                     conjugate, out);     \
    }                                                                     \
    break
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);
#undef HANDLE_DIM
      default:
        internal::TransposeSimple<T, conjugate>(d, in, perm, out);
        break;
    }
  }
};

template <bool conjugate>
struct Transpose<GPUDevice, string, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
  }
};

// Explicit instantiation.
template struct Transpose<GPUDevice, string, false>;

template <>
Status DoTranspose(const GPUDevice& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
}
template <>
Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
}
template <>
Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
                         Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
}
template <>
Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
                                  Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
}

}  // namespace tensorflow
#endif  // GOOGLE_CUDA
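
// Usage sketch (illustrative only; `ctx` is a hypothetical OpKernelContext*):
// on a GPU device, the entry points above are typically reached as
//
//   const GPUDevice& d = ctx->eigen_device<GPUDevice>();
//   OP_REQUIRES_OK(ctx, DoTranspose(d, in_tensor, perm, &out_tensor));
//
// where `perm` maps each output dimension to the input dimension it comes
// from, i.e. out.dim_size(i) == in.dim_size(perm[i]).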