1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#ifndef EIGEN_MATRIX_SQUARE_ROOT 11#define EIGEN_MATRIX_SQUARE_ROOT 12 13namespace Eigen { 14 15/** \ingroup MatrixFunctions_Module 16 * \brief Class for computing matrix square roots of upper quasi-triangular matrices. 17 * \tparam MatrixType type of the argument of the matrix square root, 18 * expected to be an instantiation of the Matrix class template. 19 * 20 * This class computes the square root of the upper quasi-triangular 21 * matrix stored in the upper Hessenberg part of the matrix passed to 22 * the constructor. 23 * 24 * \sa MatrixSquareRoot, MatrixSquareRootTriangular 25 */ 26template <typename MatrixType> 27class MatrixSquareRootQuasiTriangular 28{ 29 public: 30 31 /** \brief Constructor. 32 * 33 * \param[in] A upper quasi-triangular matrix whose square root 34 * is to be computed. 35 * 36 * The class stores a reference to \p A, so it should not be 37 * changed (or destroyed) before compute() is called. 38 */ 39 MatrixSquareRootQuasiTriangular(const MatrixType& A) 40 : m_A(A) 41 { 42 eigen_assert(A.rows() == A.cols()); 43 } 44 45 /** \brief Compute the matrix square root 46 * 47 * \param[out] result square root of \p A, as specified in the constructor. 48 * 49 * Only the upper Hessenberg part of \p result is updated, the 50 * rest is not touched. See MatrixBase::sqrt() for details on 51 * how this computation is implemented. 52 */ 53 template <typename ResultType> void compute(ResultType &result); 54 55 private: 56 typedef typename MatrixType::Index Index; 57 typedef typename MatrixType::Scalar Scalar; 58 59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 61 void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); 62 void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 63 typename MatrixType::Index i, typename MatrixType::Index j); 64 void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 65 typename MatrixType::Index i, typename MatrixType::Index j); 66 void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 67 typename MatrixType::Index i, typename MatrixType::Index j); 68 void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 69 typename MatrixType::Index i, typename MatrixType::Index j); 70 71 template <typename SmallMatrixType> 72 static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 73 const SmallMatrixType& B, const SmallMatrixType& C); 74 75 const MatrixType& m_A; 76}; 77 78template <typename MatrixType> 79template <typename ResultType> 80void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) 81{ 82 result.resize(m_A.rows(), m_A.cols()); 83 computeDiagonalPartOfSqrt(result, m_A); 84 computeOffDiagonalPartOfSqrt(result, m_A); 85} 86 87// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size 88// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T 89template <typename MatrixType> 90void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 91 const MatrixType& T) 92{ 93 using std::sqrt; 94 const Index size = m_A.rows(); 95 for (Index i = 0; i < size; i++) { 96 if (i == size - 1 || T.coeff(i+1, i) == 0) { 97 eigen_assert(T(i,i) >= 0); 98 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); 99 } 100 else { 101 compute2x2diagonalBlock(sqrtT, T, i); 102 ++i; 103 } 104 } 105} 106 107// pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. 108// post: sqrtT is the square root of T. 109template <typename MatrixType> 110void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 111 const MatrixType& T) 112{ 113 const Index size = m_A.rows(); 114 for (Index j = 1; j < size; j++) { 115 if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block 116 continue; 117 for (Index i = j-1; i >= 0; i--) { 118 if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block 119 continue; 120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); 121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); 122 if (iBlockIs2x2 && jBlockIs2x2) 123 compute2x2offDiagonalBlock(sqrtT, T, i, j); 124 else if (iBlockIs2x2 && !jBlockIs2x2) 125 compute2x1offDiagonalBlock(sqrtT, T, i, j); 126 else if (!iBlockIs2x2 && jBlockIs2x2) 127 compute1x2offDiagonalBlock(sqrtT, T, i, j); 128 else if (!iBlockIs2x2 && !jBlockIs2x2) 129 compute1x1offDiagonalBlock(sqrtT, T, i, j); 130 } 131 } 132} 133 134// pre: T.block(i,i,2,2) has complex conjugate eigenvalues 135// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) 136template <typename MatrixType> 137void MatrixSquareRootQuasiTriangular<MatrixType> 138 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) 139{ 140 // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere 141 // in EigenSolver. If we expose it, we could call it directly from here. 142 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); 143 EigenSolver<Matrix<Scalar,2,2> > es(block); 144 sqrtT.template block<2,2>(i,i) 145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real(); 146} 147 148// pre: block structure of T is such that (i,j) is a 1x1 block, 149// all blocks of sqrtT to left of and below (i,j) are correct 150// post: sqrtT(i,j) has the correct value 151template <typename MatrixType> 152void MatrixSquareRootQuasiTriangular<MatrixType> 153 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 154 typename MatrixType::Index i, typename MatrixType::Index j) 155{ 156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); 157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); 158} 159 160// similar to compute1x1offDiagonalBlock() 161template <typename MatrixType> 162void MatrixSquareRootQuasiTriangular<MatrixType> 163 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 164 typename MatrixType::Index i, typename MatrixType::Index j) 165{ 166 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); 167 if (j-i > 1) 168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); 169 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity(); 170 A += sqrtT.template block<2,2>(j,j).transpose(); 171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose()); 172} 173 174// similar to compute1x1offDiagonalBlock() 175template <typename MatrixType> 176void MatrixSquareRootQuasiTriangular<MatrixType> 177 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 178 typename MatrixType::Index i, typename MatrixType::Index j) 179{ 180 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); 181 if (j-i > 2) 182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); 183 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity(); 184 A += sqrtT.template block<2,2>(i,i); 185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs); 186} 187 188// similar to compute1x1offDiagonalBlock() 189template <typename MatrixType> 190void MatrixSquareRootQuasiTriangular<MatrixType> 191 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 192 typename MatrixType::Index i, typename MatrixType::Index j) 193{ 194 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); 195 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); 196 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); 197 if (j-i > 2) 198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); 199 Matrix<Scalar,2,2> X; 200 solveAuxiliaryEquation(X, A, B, C); 201 sqrtT.template block<2,2>(i,j) = X; 202} 203 204// solves the equation A X + X B = C where all matrices are 2-by-2 205template <typename MatrixType> 206template <typename SmallMatrixType> 207void MatrixSquareRootQuasiTriangular<MatrixType> 208 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 209 const SmallMatrixType& B, const SmallMatrixType& C) 210{ 211 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), 212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); 213 214 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); 215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); 216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); 217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0); 218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1); 219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0); 220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1); 221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1); 222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1); 223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0); 224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0); 225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0); 226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1); 227 228 Matrix<Scalar,4,1> rhs; 229 rhs.coeffRef(0) = C.coeff(0,0); 230 rhs.coeffRef(1) = C.coeff(0,1); 231 rhs.coeffRef(2) = C.coeff(1,0); 232 rhs.coeffRef(3) = C.coeff(1,1); 233 234 Matrix<Scalar,4,1> result; 235 result = coeffMatrix.fullPivLu().solve(rhs); 236 237 X.coeffRef(0,0) = result.coeff(0); 238 X.coeffRef(0,1) = result.coeff(1); 239 X.coeffRef(1,0) = result.coeff(2); 240 X.coeffRef(1,1) = result.coeff(3); 241} 242 243 244/** \ingroup MatrixFunctions_Module 245 * \brief Class for computing matrix square roots of upper triangular matrices. 246 * \tparam MatrixType type of the argument of the matrix square root, 247 * expected to be an instantiation of the Matrix class template. 248 * 249 * This class computes the square root of the upper triangular matrix 250 * stored in the upper triangular part (including the diagonal) of 251 * the matrix passed to the constructor. 252 * 253 * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular 254 */ 255template <typename MatrixType> 256class MatrixSquareRootTriangular 257{ 258 public: 259 MatrixSquareRootTriangular(const MatrixType& A) 260 : m_A(A) 261 { 262 eigen_assert(A.rows() == A.cols()); 263 } 264 265 /** \brief Compute the matrix square root 266 * 267 * \param[out] result square root of \p A, as specified in the constructor. 268 * 269 * Only the upper triangular part (including the diagonal) of 270 * \p result is updated, the rest is not touched. See 271 * MatrixBase::sqrt() for details on how this computation is 272 * implemented. 273 */ 274 template <typename ResultType> void compute(ResultType &result); 275 276 private: 277 const MatrixType& m_A; 278}; 279 280template <typename MatrixType> 281template <typename ResultType> 282void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) 283{ 284 using std::sqrt; 285 286 // Compute square root of m_A and store it in upper triangular part of result 287 // This uses that the square root of triangular matrices can be computed directly. 288 result.resize(m_A.rows(), m_A.cols()); 289 typedef typename MatrixType::Index Index; 290 for (Index i = 0; i < m_A.rows(); i++) { 291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); 292 } 293 for (Index j = 1; j < m_A.cols(); j++) { 294 for (Index i = j-1; i >= 0; i--) { 295 typedef typename MatrixType::Scalar Scalar; 296 // if i = j-1, then segment has length 0 so tmp = 0 297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); 298 // denominator may be zero if original matrix is singular 299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); 300 } 301 } 302} 303 304 305/** \ingroup MatrixFunctions_Module 306 * \brief Class for computing matrix square roots of general matrices. 307 * \tparam MatrixType type of the argument of the matrix square root, 308 * expected to be an instantiation of the Matrix class template. 309 * 310 * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() 311 */ 312template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> 313class MatrixSquareRoot 314{ 315 public: 316 317 /** \brief Constructor. 318 * 319 * \param[in] A matrix whose square root is to be computed. 320 * 321 * The class stores a reference to \p A, so it should not be 322 * changed (or destroyed) before compute() is called. 323 */ 324 MatrixSquareRoot(const MatrixType& A); 325 326 /** \brief Compute the matrix square root 327 * 328 * \param[out] result square root of \p A, as specified in the constructor. 329 * 330 * See MatrixBase::sqrt() for details on how this computation is 331 * implemented. 332 */ 333 template <typename ResultType> void compute(ResultType &result); 334}; 335 336 337// ********** Partial specialization for real matrices ********** 338 339template <typename MatrixType> 340class MatrixSquareRoot<MatrixType, 0> 341{ 342 public: 343 344 MatrixSquareRoot(const MatrixType& A) 345 : m_A(A) 346 { 347 eigen_assert(A.rows() == A.cols()); 348 } 349 350 template <typename ResultType> void compute(ResultType &result) 351 { 352 // Compute Schur decomposition of m_A 353 const RealSchur<MatrixType> schurOfA(m_A); 354 const MatrixType& T = schurOfA.matrixT(); 355 const MatrixType& U = schurOfA.matrixU(); 356 357 // Compute square root of T 358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); 359 MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT); 360 361 // Compute square root of m_A 362 result = U * sqrtT * U.adjoint(); 363 } 364 365 private: 366 const MatrixType& m_A; 367}; 368 369 370// ********** Partial specialization for complex matrices ********** 371 372template <typename MatrixType> 373class MatrixSquareRoot<MatrixType, 1> 374{ 375 public: 376 377 MatrixSquareRoot(const MatrixType& A) 378 : m_A(A) 379 { 380 eigen_assert(A.rows() == A.cols()); 381 } 382 383 template <typename ResultType> void compute(ResultType &result) 384 { 385 // Compute Schur decomposition of m_A 386 const ComplexSchur<MatrixType> schurOfA(m_A); 387 const MatrixType& T = schurOfA.matrixT(); 388 const MatrixType& U = schurOfA.matrixU(); 389 390 // Compute square root of T 391 MatrixType sqrtT; 392 MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); 393 394 // Compute square root of m_A 395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); 396 } 397 398 private: 399 const MatrixType& m_A; 400}; 401 402 403/** \ingroup MatrixFunctions_Module 404 * 405 * \brief Proxy for the matrix square root of some matrix (expression). 406 * 407 * \tparam Derived Type of the argument to the matrix square root. 408 * 409 * This class holds the argument to the matrix square root until it 410 * is assigned or evaluated for some other reason (so the argument 411 * should not be changed in the meantime). It is the return type of 412 * MatrixBase::sqrt() and most of the time this is the only way it is 413 * used. 414 */ 415template<typename Derived> class MatrixSquareRootReturnValue 416: public ReturnByValue<MatrixSquareRootReturnValue<Derived> > 417{ 418 typedef typename Derived::Index Index; 419 public: 420 /** \brief Constructor. 421 * 422 * \param[in] src %Matrix (expression) forming the argument of the 423 * matrix square root. 424 */ 425 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { } 426 427 /** \brief Compute the matrix square root. 428 * 429 * \param[out] result the matrix square root of \p src in the 430 * constructor. 431 */ 432 template <typename ResultType> 433 inline void evalTo(ResultType& result) const 434 { 435 const typename Derived::PlainObject srcEvaluated = m_src.eval(); 436 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated); 437 me.compute(result); 438 } 439 440 Index rows() const { return m_src.rows(); } 441 Index cols() const { return m_src.cols(); } 442 443 protected: 444 const Derived& m_src; 445 private: 446 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&); 447}; 448 449namespace internal { 450template<typename Derived> 451struct traits<MatrixSquareRootReturnValue<Derived> > 452{ 453 typedef typename Derived::PlainObject ReturnType; 454}; 455} 456 457template <typename Derived> 458const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const 459{ 460 eigen_assert(rows() == cols()); 461 return MatrixSquareRootReturnValue<Derived>(derived()); 462} 463 464} // end namespace Eigen 465 466#endif // EIGEN_MATRIX_FUNCTION 467