/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// TODO(yangzihao): Remove the dependency of conv_2d.h once we move all
// GPU util functions and transpose kernels into separate files.
#include "tensorflow/core/kernels/conv_2d.h"

typedef Eigen::GpuDevice GPUDevice;

namespace tensorflow {
namespace internal {

// Generic N-dimensional transpose kernel.
//
// `buf` is a single packed array of 3 * ndims int32 values laid out as
//   [0, ndims)            input strides
//   [ndims, 2 * ndims)    output strides
//   [2 * ndims, 3 * ndims) permutation
// For every flat output index, the kernel peels off one output coordinate per
// dimension using the output strides and accumulates the matching flat input
// index through `in_strides[perm[i]]`. When `conjugate` is true the value is
// passed through Eigen::numext::conj before being stored.
template <typename T, bool conjugate>
__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
                                const int32 ndims, T* dst) {
  const int32* in_strides = buf;
  const int32* out_strides = buf + ndims;
  const int32* perm = buf + ndims * 2;
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    int32 i_idx = 0;
    int32 t = o_idx;
    for (int32 i = 0; i < ndims; ++i) {
      // `ratio` is the coordinate of the output element along dimension i.
      const int32 ratio = t / out_strides[i];
      t -= ratio * out_strides[i];
      i_idx += ratio * in_strides[perm[i]];
    }
    if (conjugate) {
      dst[o_idx] = Eigen::numext::conj(ldg(src + i_idx));
    } else {
      dst[o_idx] = ldg(src + i_idx);
    }
  }
}

// Transposes `in` into `out` according to `perm` using TransposeKernel.
//
// Indices are computed in 32 bits on the device, so the call CHECK-fails if
// the tensor has kint32max or more elements. Strides and the permutation are
// packed into one host buffer, copied to a temporary device buffer, and the
// kernel is launched on the device's stream.
template <typename T, bool conjugate>
void TransposeSimple(const GPUDevice& d, const Tensor& in,
                     const gtl::ArraySlice<int32> perm, Tensor* out) {
  // Ensures we can use 32-bit index.
  const int64 nelem = in.NumElements();
  CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
  // Pack strides and permutation into one buffer.
  const int32 ndims = in.dims();
  gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
  gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
  gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
  // Dimension permutation.
  for (int i = 0; i < ndims; ++i) {
    host_buf[i] = in_strides[i];
    host_buf[ndims + i] = out_strides[i];
    host_buf[ndims * 2 + i] = perm[i];
  }
  // Copies the input strides, output strides and permutation to the device.
  // host_buf stores int32 elements, so the byte count must be based on
  // sizeof(int32); sizing it with a wider type would make the copy below read
  // past the end of host_buf.
  auto num_bytes = sizeof(int32) * host_buf.size();
  auto dev_buf = d.allocate(num_bytes);
  // NOTE: host_buf is not allocated by CudaHostAllocator, and
  // therefore we are doing a sync copy effectively.
  d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
  // Launch kernel to q[...] = p[...].
  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
  TransposeKernel<T, conjugate>
      <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
          cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
          ndims, q);
  // Safe to deallocate immediately after the kernel launch.
  d.deallocate(dev_buf);
}

// TransposeUsingTile tries to reduce the dimension of the input tensor to 3 and
// then call special kernels to swap either dimension 1 and dimension 2 or
// dimension 0 and dimension 2. It returns true if the operation is success,
// false otherwise.
//
// A false return means nothing was launched and the caller must fall back to
// another implementation.
template <typename T, bool conjugate = false>
struct TransposeUsingTile {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    // First try to reduce the dimensions of the input tensor.
    TransposePermsVec new_perm;
    TransposeDimsVec new_dims;
    ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);

    // Only use special GPU kernel when dimension is 2 or 3.
    int dims = new_dims.size();
    if (dims < 2 || dims > 3) return false;
    auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
    auto out_data =
        reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
    switch (dims) {
      case 2:
        if (new_perm[0] == 1 && new_perm[1] == 0) {
          // Add the first dimension size as 1.
          new_dims.insert(new_dims.begin(), 1);
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        }
        break;
      case 3:
        if (new_perm == TransposePermsVec({0, 2, 1})) {
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else if (new_perm == TransposePermsVec({2, 1, 0})) {
          tensorflow::functor::SwapDimension0And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else {
          // do not handle other 3D permutations
          return false;
        }
      default:
        return false;
    }
    return false;
  }
};

// complex64 specialization: forwards to the primary template with a different
// element type. The non-conjugating path moves elements as uint64 (no
// arithmetic is needed, only data movement); the conjugating path uses float2.
// NOTE(review): presumably these types are chosen because they have the same
// size as complex64 -- confirm against the SwapDimension* functors.
template <bool conjugate>
struct TransposeUsingTile<complex64, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<uint64>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<float2, true>::run(d, in, perm, out);
    }
  }
};

// complex128 specialization: forwards to the primary template with float4 for
// the non-conjugating path and double2 for the conjugating path.
// NOTE(review): presumably these types are chosen because they have the same
// size as complex128 -- confirm against the SwapDimension* functors.
template <bool conjugate>
struct TransposeUsingTile<complex128, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<float4>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<double2, true>::run(d, in, perm, out);
    }
  }
};

}  // namespace internal

// Transpose kernel specialized for GPU Device.
//
// Dispatch by input rank:
//   * rank 2..8: try the tiled kernels first; if they decline, fall back to
//     the Eigen-based transpose instantiated for that rank.
//   * any other rank: use the generic TransposeSimple kernel.
template <typename T, bool conjugate>
struct Transpose<GPUDevice, T, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    const int dims = in.dims();
    if (dims < 2 || dims > 8) {
      internal::TransposeSimple<T, conjugate>(d, in, perm, out);
      return;
    }
    if (internal::TransposeUsingTile<T, conjugate>::run(d, in, perm, out)) {
      return;
    }
    // The tiled kernels declined; the Eigen path needs the rank as a
    // compile-time constant, hence the switch.
    switch (dims) {
      case 2:
        internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, conjugate,
                                                       out);
        break;
      case 3:
        internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, conjugate,
                                                       out);
        break;
      case 4:
        internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, conjugate,
                                                       out);
        break;
      case 5:
        internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, conjugate,
                                                       out);
        break;
      case 6:
        internal::TransposeUsingEigen<GPUDevice, T, 6>(d, in, perm, conjugate,
                                                       out);
        break;
      case 7:
        internal::TransposeUsingEigen<GPUDevice, T, 7>(d, in, perm, conjugate,
                                                       out);
        break;
      case 8:
        internal::TransposeUsingEigen<GPUDevice, T, 8>(d, in, perm, conjugate,
                                                       out);
        break;
    }
  }
};

// String tensors cannot be transposed on the GPU; reaching this is fatal.
template <bool conjugate>
struct Transpose<GPUDevice, string, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
  }
};

// Explicit instantiation.
template struct Transpose<GPUDevice, string, false>;

template <>
Status DoTranspose(const GPUDevice& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
}
template <>
Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
}
template <>
Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
                         Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
}
template <>
Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
                                  Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
}

}  // namespace tensorflow
#endif  // GOOGLE_CUDA