/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA

#include <memory>
#include <numeric>
#include <type_traits>
#include <vector>

#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// GPU kernel for the "SelfAdjointEigV2" op.
//
// Input 0 is a tensor of rank >= 2 whose two innermost dimensions form square
// matrices; all leading dimensions are flattened into a batch.
//
// Outputs:
//   0: eigenvalues, with the shape of the input minus its last dimension.
//   1: eigenvectors, with the shape of the input when the "compute_v"
//      attribute is true; otherwise a scalar-shaped tensor that this kernel
//      never writes to.
//
// The decomposition is delegated to CudaSolver::Heevd, one call per matrix in
// the batch. The solver's per-matrix status values are checked asynchronously
// before `done` is invoked.
template <class Scalar>
class SelfAdjointEigV2OpGpu : public AsyncOpKernel {
 public:
  // Reads the boolean "compute_v" attribute, which selects whether
  // eigenvectors are produced in addition to eigenvalues.
  explicit SelfAdjointEigV2OpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
  }

  // Validates the input, allocates both outputs, enqueues the per-matrix
  // solver calls and hands `done` to the asynchronous status check.
  //
  // On any validation or allocation failure the OP_REQUIRES*_ASYNC macros set
  // the error on `context`, invoke `done` and return early.
  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    const int64 n = input.dim_size(ndims - 1);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be squares, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    // Collapse all leading dimensions into a single batch dimension.
    const int64 batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Allocate outputs.
    Tensor* eigenvalues = nullptr;
    TensorShape eigenvalues_shape = input.shape();
    eigenvalues_shape.RemoveLastDims(1);
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, eigenvalues_shape, &eigenvalues),
        done);
    Tensor* eigenvectors = nullptr;
    // Without compute_v the second output is only a scalar-shaped placeholder.
    TensorShape eigenvectors_shape =
        compute_v_ ? input.shape() : TensorShape({});
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, eigenvectors_shape, &eigenvectors),
        done);

    // Nothing to compute for empty inputs; the (empty) outputs are already
    // allocated.
    if (input.NumElements() == 0) {
      done();
      return;
    }

    // Allocate workspace.
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    // The solver writes eigenvalues as RealScalar. For real Scalar types the
    // output tensor can be used directly; for complex Scalar types a scratch
    // tensor of the real type is used and cast into the output afterwards.
    Tensor eigenvalues_real;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    if (std::is_same<Scalar, RealScalar>::value) {
      eigenvalues_real = *eigenvalues;
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<RealScalar>::value,
                                         eigenvalues_shape, &eigenvalues_real),
          done);
    }

    // Working copy of the input that the solver overwrites. It may end up
    // sharing input 0's buffer, which is detected below via SharesBufferWith.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    // For real symmetric matrices, row-major and column-major are the same. For
    // complex Hermitian, row-major and column-major differ by a conjugation,
    // which is still cheaper than a transpose.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (!input.SharesBufferWith(input_copy)) {
      // Separate buffer: populate it, conjugating on the way for complex
      // types and doing a plain device copy otherwise.
      if (Eigen::NumTraits<Scalar>::IsComplex) {
        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
        conj(device, input_copy.flat<Scalar>() /*out*/,
             input.flat<Scalar>() /*in*/);
      } else {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    } else if (Eigen::NumTraits<Scalar>::IsComplex) {
      // Shared buffer: conjugate in place. The const_cast is needed because
      // `input` is only reachable through a const reference here.
      functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
      conj(device, const_cast<Tensor*>(&input)->flat<Scalar>() /*out*/,
           input.flat<Scalar>() /*in*/);
    }

    // Compute eigen decomposition in-place in input_copy.
    // dev_info holds one status entry per matrix in the batch.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "heevd"));
    auto input_copy_reshaped = input_copy.flat_inner_dims<Scalar, 3>();
    auto eigenvalues_real_reshaped =
        eigenvalues_real.flat_inner_dims<RealScalar, 2>();
    // The index type matches batch_size (int64).
    for (int64 batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Heevd(compute_v_ ? CUSOLVER_EIG_MODE_VECTOR
                                   : CUSOLVER_EIG_MODE_NOVECTOR,
                        CUBLAS_FILL_MODE_UPPER, n,
                        &input_copy_reshaped(batch, 0, 0), n,
                        &eigenvalues_real_reshaped(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }

    // For complex Scalar types, widen the real eigenvalues into the
    // Scalar-typed output tensor.
    if (!std::is_same<Scalar, RealScalar>::value) {
      functor::CastFunctor<GPUDevice, Scalar, RealScalar> cast;
      cast(device, eigenvalues->flat<Scalar>(),
           const_cast<const Tensor*>(&eigenvalues_real)->flat<RealScalar>());
    }

    if (compute_v_) {
      // Transpose eigenvectors now stored in input_copy in column-major form to
      // output in row-major form.
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
    }

    // Asynchronously check return status from cuSolver kernels. Ownership of
    // both the solver and the done callback is transferred to the checker.
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(done));
  }

 private:
  // Value of the "compute_v" attribute: true when eigenvectors are requested.
  bool compute_v_;

  TF_DISALLOW_COPY_AND_ASSIGN(SelfAdjointEigV2OpGpu);
};

// Registers the GPU kernel for one element type "T".
#define REGISTER(Scalar)                                                       \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("SelfAdjointEigV2").Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), \
      (SelfAdjointEigV2OpGpu<Scalar>))

REGISTER(float);
REGISTER(double);
REGISTER(complex64);
REGISTER(complex128);

#undef REGISTER

}  // namespace tensorflow

#endif  // GOOGLE_CUDA