1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_MATRIX_SQUARE_ROOT
11#define EIGEN_MATRIX_SQUARE_ROOT
12
13namespace Eigen {
14
/** \ingroup MatrixFunctions_Module
  * \brief Class for computing matrix square roots of upper quasi-triangular matrices.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * The argument is read from the upper Hessenberg part of the matrix
  * passed to the constructor and is interpreted as upper
  * quasi-triangular: block upper triangular with 1-by-1 and 2-by-2
  * diagonal blocks, where a nonzero subdiagonal entry T(i+1,i) marks a
  * 2-by-2 block starting at row/column i.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootTriangular
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper quasi-triangular matrix whose square root
      *                is to be computed.
      *
      * Only a reference to \p A is kept, so \p A must stay alive and
      * unchanged until compute() has been called.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * \p result is resized to the size of \p A; afterwards only entries
      * in its upper Hessenberg part are written, entries further below
      * the diagonal are left alone.  See MatrixBase::sqrt() for details
      * on how this computation is implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

    // Square roots of the 1-by-1 and 2-by-2 diagonal blocks.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);

    // Blocks above the block diagonal; the helper used depends on the
    // sizes (1 or 2) of the diagonal blocks at i and at j.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves the 2-by-2 Sylvester equation A X + X B = C.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // Argument of the square root (not owned).
    const MatrixType& m_A;
};
77
78template <typename MatrixType>
79template <typename ResultType>
80void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
81{
82  result.resize(m_A.rows(), m_A.cols());
83  computeDiagonalPartOfSqrt(result, m_A);
84  computeOffDiagonalPartOfSqrt(result, m_A);
85}
86
87// pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
88// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
89template <typename MatrixType>
90void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
91									  const MatrixType& T)
92{
93  using std::sqrt;
94  const Index size = m_A.rows();
95  for (Index i = 0; i < size; i++) {
96    if (i == size - 1 || T.coeff(i+1, i) == 0) {
97      eigen_assert(T(i,i) >= 0);
98      sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
99    }
100    else {
101      compute2x2diagonalBlock(sqrtT, T, i);
102      ++i;
103    }
104  }
105}
106
107// pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
108// post: sqrtT is the square root of T.
109template <typename MatrixType>
110void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
111									     const MatrixType& T)
112{
113  const Index size = m_A.rows();
114  for (Index j = 1; j < size; j++) {
115      if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
116	continue;
117    for (Index i = j-1; i >= 0; i--) {
118      if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
119	continue;
120      bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
121      bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
122      if (iBlockIs2x2 && jBlockIs2x2)
123	compute2x2offDiagonalBlock(sqrtT, T, i, j);
124      else if (iBlockIs2x2 && !jBlockIs2x2)
125	compute2x1offDiagonalBlock(sqrtT, T, i, j);
126      else if (!iBlockIs2x2 && jBlockIs2x2)
127	compute1x2offDiagonalBlock(sqrtT, T, i, j);
128      else if (!iBlockIs2x2 && !jBlockIs2x2)
129	compute1x1offDiagonalBlock(sqrtT, T, i, j);
130    }
131  }
132}
133
134// pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
135// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
136template <typename MatrixType>
137void MatrixSquareRootQuasiTriangular<MatrixType>
138     ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
139{
140  // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
141  //       in EigenSolver. If we expose it, we could call it directly from here.
142  Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
143  EigenSolver<Matrix<Scalar,2,2> > es(block);
144  sqrtT.template block<2,2>(i,i)
145    = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
146}
147
148// pre:  block structure of T is such that (i,j) is a 1x1 block,
149//       all blocks of sqrtT to left of and below (i,j) are correct
150// post: sqrtT(i,j) has the correct value
151template <typename MatrixType>
152void MatrixSquareRootQuasiTriangular<MatrixType>
153     ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
154				  typename MatrixType::Index i, typename MatrixType::Index j)
155{
156  Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
157  sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
158}
159
160// similar to compute1x1offDiagonalBlock()
161template <typename MatrixType>
162void MatrixSquareRootQuasiTriangular<MatrixType>
163     ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
164				  typename MatrixType::Index i, typename MatrixType::Index j)
165{
166  Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
167  if (j-i > 1)
168    rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
169  Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
170  A += sqrtT.template block<2,2>(j,j).transpose();
171  sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
172}
173
174// similar to compute1x1offDiagonalBlock()
175template <typename MatrixType>
176void MatrixSquareRootQuasiTriangular<MatrixType>
177     ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
178				  typename MatrixType::Index i, typename MatrixType::Index j)
179{
180  Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
181  if (j-i > 2)
182    rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
183  Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
184  A += sqrtT.template block<2,2>(i,i);
185  sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
186}
187
188// similar to compute1x1offDiagonalBlock()
189template <typename MatrixType>
190void MatrixSquareRootQuasiTriangular<MatrixType>
191     ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
192				  typename MatrixType::Index i, typename MatrixType::Index j)
193{
194  Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
195  Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
196  Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
197  if (j-i > 2)
198    C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
199  Matrix<Scalar,2,2> X;
200  solveAuxiliaryEquation(X, A, B, C);
201  sqrtT.template block<2,2>(i,j) = X;
202}
203
204// solves the equation A X + X B = C where all matrices are 2-by-2
205template <typename MatrixType>
206template <typename SmallMatrixType>
207void MatrixSquareRootQuasiTriangular<MatrixType>
208     ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
209			      const SmallMatrixType& B, const SmallMatrixType& C)
210{
211  EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
212		      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
213
214  Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
215  coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
216  coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
217  coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
218  coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
219  coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
220  coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
221  coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
222  coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
223  coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
224  coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
225  coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
226  coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
227
228  Matrix<Scalar,4,1> rhs;
229  rhs.coeffRef(0) = C.coeff(0,0);
230  rhs.coeffRef(1) = C.coeff(0,1);
231  rhs.coeffRef(2) = C.coeff(1,0);
232  rhs.coeffRef(3) = C.coeff(1,1);
233
234  Matrix<Scalar,4,1> result;
235  result = coeffMatrix.fullPivLu().solve(rhs);
236
237  X.coeffRef(0,0) = result.coeff(0);
238  X.coeffRef(0,1) = result.coeff(1);
239  X.coeffRef(1,0) = result.coeff(2);
240  X.coeffRef(1,1) = result.coeff(3);
241}
242
243
/** \ingroup MatrixFunctions_Module
  * \brief Class for computing matrix square roots of upper triangular matrices.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * The argument is taken from the upper triangular part (diagonal
  * included) of the matrix passed to the constructor; entries below the
  * diagonal are never read.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper triangular matrix whose square root is to be
      *                computed.
      *
      * Only a reference to \p A is kept, so \p A must stay alive and
      * unchanged until compute() has been called.
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * \p result is resized to the size of \p A; afterwards only its upper
      * triangular part (including the diagonal) is written.  See
      * MatrixBase::sqrt() for details on how this computation is
      * implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Argument of the square root (not owned).
    const MatrixType& m_A;
};
279
280template <typename MatrixType>
281template <typename ResultType>
282void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
283{
284  using std::sqrt;
285
286  // Compute square root of m_A and store it in upper triangular part of result
287  // This uses that the square root of triangular matrices can be computed directly.
288  result.resize(m_A.rows(), m_A.cols());
289  typedef typename MatrixType::Index Index;
290  for (Index i = 0; i < m_A.rows(); i++) {
291    result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
292  }
293  for (Index j = 1; j < m_A.cols(); j++) {
294    for (Index i = j-1; i >= 0; i--) {
295      typedef typename MatrixType::Scalar Scalar;
296      // if i = j-1, then segment has length 0 so tmp = 0
297      Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
298      // denominator may be zero if original matrix is singular
299      result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
300    }
301  }
302}
303
304
305/** \ingroup MatrixFunctions_Module
306  * \brief Class for computing matrix square roots of general matrices.
307  * \tparam  MatrixType  type of the argument of the matrix square root,
308  *                      expected to be an instantiation of the Matrix class template.
309  *
310  * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
311  */
312template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
313class MatrixSquareRoot
314{
315  public:
316
317    /** \brief Constructor.
318      *
319      * \param[in]  A  matrix whose square root is to be computed.
320      *
321      * The class stores a reference to \p A, so it should not be
322      * changed (or destroyed) before compute() is called.
323      */
324    MatrixSquareRoot(const MatrixType& A);
325
326    /** \brief Compute the matrix square root
327      *
328      * \param[out] result  square root of \p A, as specified in the constructor.
329      *
330      * See MatrixBase::sqrt() for details on how this computation is
331      * implemented.
332      */
333    template <typename ResultType> void compute(ResultType &result);
334};
335
336
337// ********** Partial specialization for real matrices **********
338
339template <typename MatrixType>
340class MatrixSquareRoot<MatrixType, 0>
341{
342  public:
343
344    MatrixSquareRoot(const MatrixType& A)
345      : m_A(A)
346    {
347      eigen_assert(A.rows() == A.cols());
348    }
349
350    template <typename ResultType> void compute(ResultType &result)
351    {
352      // Compute Schur decomposition of m_A
353      const RealSchur<MatrixType> schurOfA(m_A);
354      const MatrixType& T = schurOfA.matrixT();
355      const MatrixType& U = schurOfA.matrixU();
356
357      // Compute square root of T
358      MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
359      MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
360
361      // Compute square root of m_A
362      result = U * sqrtT * U.adjoint();
363    }
364
365  private:
366    const MatrixType& m_A;
367};
368
369
370// ********** Partial specialization for complex matrices **********
371
372template <typename MatrixType>
373class MatrixSquareRoot<MatrixType, 1>
374{
375  public:
376
377    MatrixSquareRoot(const MatrixType& A)
378      : m_A(A)
379    {
380      eigen_assert(A.rows() == A.cols());
381    }
382
383    template <typename ResultType> void compute(ResultType &result)
384    {
385      // Compute Schur decomposition of m_A
386      const ComplexSchur<MatrixType> schurOfA(m_A);
387      const MatrixType& T = schurOfA.matrixT();
388      const MatrixType& U = schurOfA.matrixU();
389
390      // Compute square root of T
391      MatrixType sqrtT;
392      MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
393
394      // Compute square root of m_A
395      result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
396    }
397
398  private:
399    const MatrixType& m_A;
400};
401
402
403/** \ingroup MatrixFunctions_Module
404  *
405  * \brief Proxy for the matrix square root of some matrix (expression).
406  *
407  * \tparam Derived  Type of the argument to the matrix square root.
408  *
409  * This class holds the argument to the matrix square root until it
410  * is assigned or evaluated for some other reason (so the argument
411  * should not be changed in the meantime). It is the return type of
412  * MatrixBase::sqrt() and most of the time this is the only way it is
413  * used.
414  */
415template<typename Derived> class MatrixSquareRootReturnValue
416: public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
417{
418    typedef typename Derived::Index Index;
419  public:
420    /** \brief Constructor.
421      *
422      * \param[in]  src  %Matrix (expression) forming the argument of the
423      * matrix square root.
424      */
425    MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
426
427    /** \brief Compute the matrix square root.
428      *
429      * \param[out]  result  the matrix square root of \p src in the
430      * constructor.
431      */
432    template <typename ResultType>
433    inline void evalTo(ResultType& result) const
434    {
435      const typename Derived::PlainObject srcEvaluated = m_src.eval();
436      MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
437      me.compute(result);
438    }
439
440    Index rows() const { return m_src.rows(); }
441    Index cols() const { return m_src.cols(); }
442
443  protected:
444    const Derived& m_src;
445  private:
446    MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
447};
448
449namespace internal {
450template<typename Derived>
451struct traits<MatrixSquareRootReturnValue<Derived> >
452{
453  typedef typename Derived::PlainObject ReturnType;
454};
455}
456
457template <typename Derived>
458const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
459{
460  eigen_assert(rows() == cols());
461  return MatrixSquareRootReturnValue<Derived>(derived());
462}
463
464} // end namespace Eigen
465
#endif // EIGEN_MATRIX_SQUARE_ROOT
467