1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#if GOOGLE_CUDA
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_GPU
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
206804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2190d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai#include "tensorflow/core/kernels/ops_util.h"
22d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower#include "tensorflow/core/kernels/transpose_functor.h"
236804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower#include "tensorflow/core/util/cuda_kernel_helper.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// TODO(yangzihao): Remove the dependency on conv_2d.h once we move all
// GPU util functions and transpose kernels into separate files.
27ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang#include "tensorflow/core/kernels/conv_2d.h"
28ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang
29d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertypedef Eigen::GpuDevice GPUDevice;
30d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
326804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlowernamespace internal {
336804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower
34d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertemplate <typename T, bool conjugate>
356804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
366804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower                                const int32 ndims, T* dst) {
376804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const int32* in_strides = buf;
386804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const int32* out_strides = buf + ndims;
396804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const int32* perm = buf + ndims * 2;
406804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
416804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower    int32 i_idx = 0;
426804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower    int32 t = o_idx;
435a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower    for (int32 i = 0; i < ndims; ++i) {
445a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      const int32 ratio = t / out_strides[i];
455a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      t -= ratio * out_strides[i];
465a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      i_idx += ratio * in_strides[perm[i]];
476804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower    }
48d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    if (conjugate) {
49d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      dst[o_idx] = Eigen::numext::conj(ldg(src + i_idx));
50d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    } else {
51d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      dst[o_idx] = ldg(src + i_idx);
52d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    }
536804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  }
546804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower}
556804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower
5647e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlowertemplate <typename T, bool conjugate>
5747e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlowervoid TransposeSimple(const GPUDevice& d, const Tensor& in,
586804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower                     const gtl::ArraySlice<int32> perm, Tensor* out) {
596804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Ensures we can use 32-bit index.
606804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const int64 nelem = in.NumElements();
616804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
626804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Pack strides and permutation into one buffer.
636804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const int32 ndims = in.dims();
6490d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
6590d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
6690d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
676804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Dimension permutation.
686804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  for (int i = 0; i < ndims; ++i) {
6990d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    host_buf[i] = in_strides[i];
7090d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    host_buf[ndims + i] = out_strides[i];
716804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower    host_buf[ndims * 2 + i] = perm[i];
726804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  }
736804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Copies the input strides, output strides and permutation to the device.
746804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  auto num_bytes = sizeof(int64) * host_buf.size();
756804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  auto dev_buf = d.allocate(num_bytes);
766804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // NOTE: host_buf is not allocated by CudaHostAllocator, and
776804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // therefore we are doing a sync copy effectively.
786804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
796804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Launch kernel to q[...] = p[...].
806804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
816804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
826804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
83d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  TransposeKernel<T, conjugate>
84d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
85d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
86d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          ndims, q);
876804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  // Safe to deallocate immediately after the kernel launch.
886804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower  d.deallocate(dev_buf);
896804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower}
906804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower
91ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang// TransposeUsingTile tries to reduce the dimension of the input tensor to 3 and
92ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang// then call special kernels to swap either dimension 1 and dimension 2 or
93ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang// dimension 0 and dimension 2. It returns true if the operation is success,
94ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang// false otherwise.
95d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertemplate <typename T, bool conjugate = false>
96d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowerstruct TransposeUsingTile {
97d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
98d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                  const gtl::ArraySlice<int32> perm, Tensor* out) {
99d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    // First try to reduce the dimensions of the input tensor.
100d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    TransposePermsVec new_perm;
101d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    TransposeDimsVec new_dims;
102d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);
103d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower
104d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    // Only use special GPU kernel when dimension is 2 or 3.
105d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    int dims = new_dims.size();
106d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    if (dims < 2 || dims > 3) return false;
107d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
108d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    auto out_data =
109d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
110d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    switch (dims) {
111d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      case 2:
112d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (new_perm[0] == 1 && new_perm[1] == 0) {
113d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          // Add the first dimension size as 1.
114d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          new_dims.insert(new_dims.begin(), 1);
1155a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
1165a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower                                                           conjugate>()(
117d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower              d, in_data, new_dims, out_data);
118d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          return true;
119d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        }
120d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        break;
121d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      case 3:
122d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (new_perm == TransposePermsVec({0, 2, 1})) {
1235a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
1245a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower                                                           conjugate>()(
125d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower              d, in_data, new_dims, out_data);
126d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          return true;
127d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        } else if (new_perm == TransposePermsVec({2, 1, 0})) {
1285a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower          tensorflow::functor::SwapDimension0And2InTensor3<GPUDevice, T,
1295a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower                                                           conjugate>()(
130d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower              d, in_data, new_dims, out_data);
131d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          return true;
132d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        } else {
133d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          // do not handle other 3D permutations
134d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          return false;
135d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        }
136d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        break;
137d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower      default:
138ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        return false;
139d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    }
140d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    return false;
141d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  }
142d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower};
143d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower
144d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertemplate <bool conjugate>
145d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowerstruct TransposeUsingTile<complex64, conjugate> {
146d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
147d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                  const gtl::ArraySlice<int32> perm, Tensor* out) {
1485a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower    if (!conjugate) {
1495a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      return TransposeUsingTile<uint64>::run(d, in, perm, out);
1505a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower    } else {
1515a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      return TransposeUsingTile<float2, true>::run(d, in, perm, out);
152d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    }
153ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang  }
154d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower};
155d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower
156d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertemplate <bool conjugate>
157d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowerstruct TransposeUsingTile<complex128, conjugate> {
158d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
159d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                  const gtl::ArraySlice<int32> perm, Tensor* out) {
1605a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower    if (!conjugate) {
1615a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      return TransposeUsingTile<float4>::run(d, in, perm, out);
1625a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower    } else {
1635a8c47079f664b280bb28eb34ce2c93534305cdaA. Unique TensorFlower      return TransposeUsingTile<double2, true>::run(d, in, perm, out);
164d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower    }
165d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower  }
166d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower};
167ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang
16847e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower}  // namespace internal
169ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang
17047e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower// Transpose kernel specialized for GPU Device.
171d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowertemplate <typename T, bool conjugate>
172d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlowerstruct Transpose<GPUDevice, T, conjugate> {
173ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang  static void run(const GPUDevice& d, const Tensor& in,
174ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang                  const gtl::ArraySlice<int32> perm, Tensor* out) {
175ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang    switch (in.dims()) {
176ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang      case 2:
177d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
178d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                             out)) {
179d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, conjugate,
180d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                         out);
181ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        }
182ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        break;
183ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang      case 3:
184d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
185d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                             out)) {
186d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, conjugate,
187d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                         out);
188ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        }
189ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        break;
190ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang      case 4:
191d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
192d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                             out)) {
193d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, conjugate,
194d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                         out);
195ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        }
196ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        break;
197ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang      case 5:
198d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
199d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                             out)) {
200d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower          internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, conjugate,
201d835d677ade78a41e0e097f67c87b6ab8588a90aA. Unique TensorFlower                                                         out);
202ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        }
203ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        break;
204592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov      case 6:
205592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
206592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                             out)) {
207592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov          internal::TransposeUsingEigen<GPUDevice, T, 6>(d, in, perm, conjugate,
208592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                         out);
209592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        }
210592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        break;
211592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov      case 7:
212592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
213592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                             out)) {
214592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov          internal::TransposeUsingEigen<GPUDevice, T, 7>(d, in, perm, conjugate,
215592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                         out);
216592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        }
217592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        break;
218592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov      case 8:
219592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
220592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                             out)) {
221592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov          internal::TransposeUsingEigen<GPUDevice, T, 8>(d, in, perm, conjugate,
222592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov                                                         out);
223592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        }
224592d2d67daca18db98c7f67b0a55ef487ed76f1cValentin Khrulkov        break;
225ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang      default:
22647e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower        internal::TransposeSimple<T, conjugate>(d, in, perm, out);
227ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang        break;
228ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang    }
229ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang  }
230ec2f8761168c40a76b95220221889b47f82700d9Yangzihao Wang};
2316804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower
// GPU specialization for string elements. String transposes are not
// implemented on the GPU, so run() always fails via LOG(FATAL). Presumably
// this exists so generic GPU code instantiating Transpose for every type
// still compiles; confirm against callers.
template <bool conjugate>
struct Transpose<GPUDevice, string, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
  }
};

// Explicit instantiation of the non-conjugating string specialization above.
template struct Transpose<GPUDevice, string, false>;
24247e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower
// GPU entry point for a general transpose by `perm`. It forwards to
// internal::DoTransposeImpl with conjugation disabled.
template <>
Status DoTranspose(const GPUDevice& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
}
// GPU entry point for a transpose by `perm` that also conjugates each element.
// It forwards to internal::DoTransposeImpl with conjugation enabled.
template <>
Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
}
// GPU entry point for a matrix transpose. There is no explicit permutation;
// the permutation used is decided inside internal::DoMatrixTransposeImpl.
// Conjugation is disabled.
template <>
Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
                         Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
}
// GPU entry point for a conjugating matrix transpose. It forwards to
// internal::DoMatrixTransposeImpl with conjugation enabled.
template <>
Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
                                  Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
}
26347e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower
2646804c9cafc11fa73be3fdb057e033f0304661622A. Unique TensorFlower}  // namespace tensorflow
265f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // GOOGLE_CUDA
266