1e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
3e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFloweryou may not use this file except in compliance with the License.
5e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerYou may obtain a copy of the License at
6e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
7e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
9e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerSee the License for the specific language governing permissions and
13e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerlimitations under the License.
14e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower==============================================================================*/
15e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
16e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower// See docs in ../ops/linalg_ops.cc.
17e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
18e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#if GOOGLE_CUDA
19e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
20e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include <numeric>
21e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include <type_traits>
22e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
23e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#define EIGEN_USE_GPU
24e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/framework/kernel_def_builder.h"
26e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
27e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h"
28e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/kernels/cast_op.h"
29e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/kernels/cuda_solvers.h"
30e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/kernels/cwise_ops.h"
31e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/kernels/transpose_functor.h"
32e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h"
33e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/platform/logging.h"
34e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#include "tensorflow/core/platform/types.h"
35e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
36e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowernamespace tensorflow {
37e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
38e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowertypedef Eigen::GpuDevice GPUDevice;
39e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
40e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowertemplate <class Scalar>
41e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerclass SelfAdjointEigV2OpGpu : public AsyncOpKernel {
42e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower public:
43e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  explicit SelfAdjointEigV2OpGpu(OpKernelConstruction* context)
44e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      : AsyncOpKernel(context) {
45e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
46e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  }
47e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
48e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
49e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    const Tensor& input = context->input(0);
50e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    const int ndims = input.dims();
51e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_ASYNC(
52e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        context, ndims >= 2,
53e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
54e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        done);
55e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    const int64 n = input.dim_size(ndims - 1);
56e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_ASYNC(
57e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        context, input.dim_size(ndims - 2) == n,
58e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        errors::InvalidArgument("Input matrices must be squares, got",
59e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower                                input.dim_size(ndims - 2), " != ", n),
60e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        done);
61e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    const int64 batch_size =
62e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        input.template flat_inner_dims<Scalar, 3>().dimension(0);
63e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
64e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // Allocate outputs.
65e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    Tensor* eigenvalues;
66e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    TensorShape eigenvalues_shape = input.shape();
67e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    eigenvalues_shape.RemoveLastDims(1);
68e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_OK_ASYNC(
69e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        context, context->allocate_output(0, eigenvalues_shape, &eigenvalues),
70e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        done);
71e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    Tensor* eigenvectors;
72e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    TensorShape eigenvectors_shape =
73e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        compute_v_ ? input.shape() : TensorShape({});
74e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_OK_ASYNC(
75e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        context, context->allocate_output(1, eigenvectors_shape, &eigenvectors),
76e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        done);
77e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
78e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    if (input.NumElements() == 0) {
79e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      done();
80e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      return;
81e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
82e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
83e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // Allocate workspace.
84ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower    // TODO(rmlarsen): Convert to std::make_unique when available.
85ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
86e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    Tensor eigenvalues_real;
87e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
88e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    if (std::is_same<Scalar, RealScalar>::value) {
89e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      eigenvalues_real = *eigenvalues;
90e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    } else {
91e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      OP_REQUIRES_OK_ASYNC(
92e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower          context,
93ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower          solver->allocate_scoped_tensor(DataTypeToEnum<RealScalar>::value,
94ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                                         eigenvalues_shape, &eigenvalues_real),
95e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower          done);
96e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
97e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
98e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    Tensor input_copy;
99e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    OP_REQUIRES_OK_ASYNC(
100e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        context,
101ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower        solver->forward_input_or_allocate_scoped_tensor(
102e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
103e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        done);
104e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // For real symmetric matrices, row-major and column-major are the same. For
105e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // complex Hermitian, row-major and column-major differ by a conjugation,
106e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // which is still cheaper than a transpose.
107e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    const GPUDevice& device = context->eigen_device<GPUDevice>();
108e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    if (!input.SharesBufferWith(input_copy)) {
109e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      if (Eigen::NumTraits<Scalar>::IsComplex) {
110e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
111e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        conj(device, input_copy.flat<Scalar>() /*out*/,
112e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower             input.flat<Scalar>() /*in*/);
113e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      } else {
114e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        device.memcpy(input_copy.flat<Scalar>().data(),
115e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower                      input.flat<Scalar>().data(),
116e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower                      input.NumElements() * sizeof(Scalar));
117e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      }
118e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    } else if (Eigen::NumTraits<Scalar>::IsComplex) {
119e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
120e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      conj(device, const_cast<Tensor*>(&input)->flat<Scalar>() /*out*/,
121e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower           input.flat<Scalar>() /*in*/);
122e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
123e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
124e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // Compute eigen decomposition in-place in input_copy.
125e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    std::vector<DeviceLapackInfo> dev_info;
126ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "heevd"));
127e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    auto input_copy_reshaped = input_copy.flat_inner_dims<Scalar, 3>();
128e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    auto eigenvalues_real_reshaped =
129e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower        eigenvalues_real.flat_inner_dims<RealScalar, 2>();
130e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    for (int batch = 0; batch < batch_size; ++batch) {
131ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower      OP_REQUIRES_OK_ASYNC(
132ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower          context,
133ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower          solver->Heevd(compute_v_ ? CUSOLVER_EIG_MODE_VECTOR
134ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                                   : CUSOLVER_EIG_MODE_NOVECTOR,
135ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                        CUBLAS_FILL_MODE_UPPER, n,
136ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                        &input_copy_reshaped(batch, 0, 0), n,
137ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                        &eigenvalues_real_reshaped(batch, 0),
138ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                        dev_info.back().mutable_data() + batch),
139ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower          done);
140e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
141e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
142e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    if (!std::is_same<Scalar, RealScalar>::value) {
143e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      functor::CastFunctor<GPUDevice, Scalar, RealScalar> cast;
144e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      cast(device, eigenvalues->flat<Scalar>(),
145e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower           const_cast<const Tensor*>(&eigenvalues_real)->flat<RealScalar>());
146e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
147e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
148e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    if (compute_v_) {
149e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      // Transpose eigenvectors now stored in input_copy in column-major form to
150e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      // output in row-major form.
151e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      OP_REQUIRES_OK_ASYNC(
15247e4d4b6b5742350233a8fd83cd81269792ed286A. Unique TensorFlower          context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
153e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    }
154e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
155e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower    // Asynchronously check return status from cuSolver kernels.
156ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
157ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower                                                    std::move(done));
158e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  }
159e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
160e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower private:
161e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  bool compute_v_;
162e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
163e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  TF_DISALLOW_COPY_AND_ASSIGN(SelfAdjointEigV2OpGpu);
164e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower};
165e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
166e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#define REGISTER(Scalar)                                                       \
167e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower  REGISTER_KERNEL_BUILDER(                                                     \
168e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      Name("SelfAdjointEigV2").Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), \
169e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower      (SelfAdjointEigV2OpGpu<Scalar>))
170e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
171e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerREGISTER(float);
172e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerREGISTER(double);
173e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerREGISTER(complex64);
174e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlowerREGISTER(complex128);
175e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
176e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#undef REGISTER
177e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
178e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower}  // namespace tensorflow
179e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower
180e3413de529c3f762885efd62932f76445ed22653A. Unique TensorFlower#endif  // GOOGLE_CUDA
181