1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_MATRIX_SQUARE_ROOT
11#define EIGEN_MATRIX_SQUARE_ROOT
12
13namespace Eigen {
14
/** \ingroup MatrixFunctions_Module
  * \brief Computes the square root of an upper quasi-triangular matrix.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * The argument is the upper quasi-triangular matrix held in the upper
  * Hessenberg part of the matrix handed to the constructor. Its diagonal
  * consists of 1-by-1 and 2-by-2 blocks; a nonzero subdiagonal entry marks
  * the start of a 2-by-2 block.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootTriangular
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper quasi-triangular matrix whose square root
      *                is to be computed.
      *
      * Only a reference to \p A is kept. \p A must therefore stay alive
      * and unchanged until compute() has been called.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * Only the upper Hessenberg part of \p result is updated; the
      * rest is not touched.  See MatrixBase::sqrt() for details on
      * how this computation is implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:

    // Square roots of the 1-by-1 and 2-by-2 diagonal blocks of T.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Remaining blocks above the diagonal, given the diagonal blocks.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);

    // Square root of the 2-by-2 diagonal block starting at (i,i).
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i);

    // Off-diagonal block (i,j); the prefix gives its shape (rows x columns),
    // determined by the sizes of the diagonal blocks at i and at j.
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);

    // Solves A X + X B = C for 2-by-2 matrices.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // The argument; referenced, not copied.
    const MatrixType& m_A;
};
77
78template <typename MatrixType>
79template <typename ResultType>
80void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
81{
82  // Compute Schur decomposition of m_A
83  const RealSchur<MatrixType> schurOfA(m_A);
84  const MatrixType& T = schurOfA.matrixT();
85  const MatrixType& U = schurOfA.matrixU();
86
87  // Compute square root of T
88  MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
89  computeDiagonalPartOfSqrt(sqrtT, T);
90  computeOffDiagonalPartOfSqrt(sqrtT, T);
91
92  // Compute square root of m_A
93  result = U * sqrtT * U.adjoint();
94}
95
96// pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
97// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
98template <typename MatrixType>
99void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
100									  const MatrixType& T)
101{
102  const Index size = m_A.rows();
103  for (Index i = 0; i < size; i++) {
104    if (i == size - 1 || T.coeff(i+1, i) == 0) {
105      eigen_assert(T(i,i) > 0);
106      sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
107    }
108    else {
109      compute2x2diagonalBlock(sqrtT, T, i);
110      ++i;
111    }
112  }
113}
114
115// pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
116// post: sqrtT is the square root of T.
117template <typename MatrixType>
118void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
119									     const MatrixType& T)
120{
121  const Index size = m_A.rows();
122  for (Index j = 1; j < size; j++) {
123      if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
124	continue;
125    for (Index i = j-1; i >= 0; i--) {
126      if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
127	continue;
128      bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
129      bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
130      if (iBlockIs2x2 && jBlockIs2x2)
131	compute2x2offDiagonalBlock(sqrtT, T, i, j);
132      else if (iBlockIs2x2 && !jBlockIs2x2)
133	compute2x1offDiagonalBlock(sqrtT, T, i, j);
134      else if (!iBlockIs2x2 && jBlockIs2x2)
135	compute1x2offDiagonalBlock(sqrtT, T, i, j);
136      else if (!iBlockIs2x2 && !jBlockIs2x2)
137	compute1x1offDiagonalBlock(sqrtT, T, i, j);
138    }
139  }
140}
141
142// pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
143// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
144template <typename MatrixType>
145void MatrixSquareRootQuasiTriangular<MatrixType>
146     ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
147{
148  // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
149  //       in EigenSolver. If we expose it, we could call it directly from here.
150  Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
151  EigenSolver<Matrix<Scalar,2,2> > es(block);
152  sqrtT.template block<2,2>(i,i)
153    = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
154}
155
156// pre:  block structure of T is such that (i,j) is a 1x1 block,
157//       all blocks of sqrtT to left of and below (i,j) are correct
158// post: sqrtT(i,j) has the correct value
159template <typename MatrixType>
160void MatrixSquareRootQuasiTriangular<MatrixType>
161     ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
162				  typename MatrixType::Index i, typename MatrixType::Index j)
163{
164  Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
165  sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
166}
167
168// similar to compute1x1offDiagonalBlock()
169template <typename MatrixType>
170void MatrixSquareRootQuasiTriangular<MatrixType>
171     ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
172				  typename MatrixType::Index i, typename MatrixType::Index j)
173{
174  Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
175  if (j-i > 1)
176    rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
177  Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
178  A += sqrtT.template block<2,2>(j,j).transpose();
179  sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
180}
181
182// similar to compute1x1offDiagonalBlock()
183template <typename MatrixType>
184void MatrixSquareRootQuasiTriangular<MatrixType>
185     ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
186				  typename MatrixType::Index i, typename MatrixType::Index j)
187{
188  Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
189  if (j-i > 2)
190    rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
191  Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
192  A += sqrtT.template block<2,2>(i,i);
193  sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
194}
195
196// similar to compute1x1offDiagonalBlock()
197template <typename MatrixType>
198void MatrixSquareRootQuasiTriangular<MatrixType>
199     ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
200				  typename MatrixType::Index i, typename MatrixType::Index j)
201{
202  Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
203  Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
204  Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
205  if (j-i > 2)
206    C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
207  Matrix<Scalar,2,2> X;
208  solveAuxiliaryEquation(X, A, B, C);
209  sqrtT.template block<2,2>(i,j) = X;
210}
211
212// solves the equation A X + X B = C where all matrices are 2-by-2
213template <typename MatrixType>
214template <typename SmallMatrixType>
215void MatrixSquareRootQuasiTriangular<MatrixType>
216     ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
217			      const SmallMatrixType& B, const SmallMatrixType& C)
218{
219  EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
220		      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
221
222  Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
223  coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
224  coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
225  coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
226  coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
227  coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
228  coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
229  coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
230  coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
231  coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
232  coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
233  coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
234  coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
235
236  Matrix<Scalar,4,1> rhs;
237  rhs.coeffRef(0) = C.coeff(0,0);
238  rhs.coeffRef(1) = C.coeff(0,1);
239  rhs.coeffRef(2) = C.coeff(1,0);
240  rhs.coeffRef(3) = C.coeff(1,1);
241
242  Matrix<Scalar,4,1> result;
243  result = coeffMatrix.fullPivLu().solve(rhs);
244
245  X.coeffRef(0,0) = result.coeff(0);
246  X.coeffRef(0,1) = result.coeff(1);
247  X.coeffRef(1,0) = result.coeff(2);
248  X.coeffRef(1,1) = result.coeff(3);
249}
250
251
/** \ingroup MatrixFunctions_Module
  * \brief Computes the square root of an upper triangular matrix.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * The argument is the upper triangular matrix held in the upper
  * triangular part (diagonal included) of the matrix handed to the
  * constructor.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper triangular matrix whose square root is to be
      *                computed.
      *
      * Only a reference to \p A is kept. \p A must therefore stay alive
      * and unchanged until compute() has been called.
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * Only the upper triangular part (including the diagonal) of
      * \p result is updated; the rest is not touched.  See
      * MatrixBase::sqrt() for details on how this computation is
      * implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // The argument; referenced, not copied.
    const MatrixType& m_A;
};
287
288template <typename MatrixType>
289template <typename ResultType>
290void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
291{
292  // Compute Schur decomposition of m_A
293  const ComplexSchur<MatrixType> schurOfA(m_A);
294  const MatrixType& T = schurOfA.matrixT();
295  const MatrixType& U = schurOfA.matrixU();
296
297  // Compute square root of T and store it in upper triangular part of result
298  // This uses that the square root of triangular matrices can be computed directly.
299  result.resize(m_A.rows(), m_A.cols());
300  typedef typename MatrixType::Index Index;
301  for (Index i = 0; i < m_A.rows(); i++) {
302    result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
303  }
304  for (Index j = 1; j < m_A.cols(); j++) {
305    for (Index i = j-1; i >= 0; i--) {
306      typedef typename MatrixType::Scalar Scalar;
307      // if i = j-1, then segment has length 0 so tmp = 0
308      Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
309      // denominator may be zero if original matrix is singular
310      result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
311    }
312  }
313
314  // Compute square root of m_A as U * result * U.adjoint()
315  MatrixType tmp;
316  tmp.noalias() = U * result.template triangularView<Upper>();
317  result.noalias() = tmp * U.adjoint();
318}
319
320
321/** \ingroup MatrixFunctions_Module
322  * \brief Class for computing matrix square roots of general matrices.
323  * \tparam  MatrixType  type of the argument of the matrix square root,
324  *                      expected to be an instantiation of the Matrix class template.
325  *
326  * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
327  */
328template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
329class MatrixSquareRoot
330{
331  public:
332
333    /** \brief Constructor.
334      *
335      * \param[in]  A  matrix whose square root is to be computed.
336      *
337      * The class stores a reference to \p A, so it should not be
338      * changed (or destroyed) before compute() is called.
339      */
340    MatrixSquareRoot(const MatrixType& A);
341
342    /** \brief Compute the matrix square root
343      *
344      * \param[out] result  square root of \p A, as specified in the constructor.
345      *
346      * See MatrixBase::sqrt() for details on how this computation is
347      * implemented.
348      */
349    template <typename ResultType> void compute(ResultType &result);
350};
351
352
353// ********** Partial specialization for real matrices **********
354
355template <typename MatrixType>
356class MatrixSquareRoot<MatrixType, 0>
357{
358  public:
359
360    MatrixSquareRoot(const MatrixType& A)
361      : m_A(A)
362    {
363      eigen_assert(A.rows() == A.cols());
364    }
365
366    template <typename ResultType> void compute(ResultType &result)
367    {
368      // Compute Schur decomposition of m_A
369      const RealSchur<MatrixType> schurOfA(m_A);
370      const MatrixType& T = schurOfA.matrixT();
371      const MatrixType& U = schurOfA.matrixU();
372
373      // Compute square root of T
374      MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
375      MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
376      tmp.compute(sqrtT);
377
378      // Compute square root of m_A
379      result = U * sqrtT * U.adjoint();
380    }
381
382  private:
383    const MatrixType& m_A;
384};
385
386
387// ********** Partial specialization for complex matrices **********
388
389template <typename MatrixType>
390class MatrixSquareRoot<MatrixType, 1>
391{
392  public:
393
394    MatrixSquareRoot(const MatrixType& A)
395      : m_A(A)
396    {
397      eigen_assert(A.rows() == A.cols());
398    }
399
400    template <typename ResultType> void compute(ResultType &result)
401    {
402      // Compute Schur decomposition of m_A
403      const ComplexSchur<MatrixType> schurOfA(m_A);
404      const MatrixType& T = schurOfA.matrixT();
405      const MatrixType& U = schurOfA.matrixU();
406
407      // Compute square root of T
408      MatrixSquareRootTriangular<MatrixType> tmp(T);
409      MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
410      tmp.compute(sqrtT);
411
412      // Compute square root of m_A
413      result = U * sqrtT * U.adjoint();
414    }
415
416  private:
417    const MatrixType& m_A;
418};
419
420
421/** \ingroup MatrixFunctions_Module
422  *
423  * \brief Proxy for the matrix square root of some matrix (expression).
424  *
425  * \tparam Derived  Type of the argument to the matrix square root.
426  *
427  * This class holds the argument to the matrix square root until it
428  * is assigned or evaluated for some other reason (so the argument
429  * should not be changed in the meantime). It is the return type of
430  * MatrixBase::sqrt() and most of the time this is the only way it is
431  * used.
432  */
433template<typename Derived> class MatrixSquareRootReturnValue
434: public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
435{
436    typedef typename Derived::Index Index;
437  public:
438    /** \brief Constructor.
439      *
440      * \param[in]  src  %Matrix (expression) forming the argument of the
441      * matrix square root.
442      */
443    MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
444
445    /** \brief Compute the matrix square root.
446      *
447      * \param[out]  result  the matrix square root of \p src in the
448      * constructor.
449      */
450    template <typename ResultType>
451    inline void evalTo(ResultType& result) const
452    {
453      const typename Derived::PlainObject srcEvaluated = m_src.eval();
454      MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
455      me.compute(result);
456    }
457
458    Index rows() const { return m_src.rows(); }
459    Index cols() const { return m_src.cols(); }
460
461  protected:
462    const Derived& m_src;
463  private:
464    MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
465};
466
467namespace internal {
468template<typename Derived>
469struct traits<MatrixSquareRootReturnValue<Derived> >
470{
471  typedef typename Derived::PlainObject ReturnType;
472};
473}
474
475template <typename Derived>
476const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
477{
478  eigen_assert(rows() == cols());
479  return MatrixSquareRootReturnValue<Derived>(derived());
480}
481
482} // end namespace Eigen
483
#endif // EIGEN_MATRIX_SQUARE_ROOT
485