1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SOLVERBASE_H
11#define EIGEN_SOLVERBASE_H
12
13namespace Eigen {
14
15namespace internal {
16
17
18
19} // end namespace internal
20
21/** \class SolverBase
22  * \brief A base class for matrix decomposition and solvers
23  *
24  * \tparam Derived the actual type of the decomposition/solver.
25  *
26  * Any matrix decomposition inheriting this base class provide the following API:
27  *
28  * \code
29  * MatrixType A, b, x;
30  * DecompositionType dec(A);
31  * x = dec.solve(b);             // solve A   * x = b
32  * x = dec.transpose().solve(b); // solve A^T * x = b
33  * x = dec.adjoint().solve(b);   // solve A'  * x = b
34  * \endcode
35  *
36  * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors.
37  *
38  * \sa class PartialPivLU, class FullPivLU
39  */
40template<typename Derived>
41class SolverBase : public EigenBase<Derived>
42{
43  public:
44
45    typedef EigenBase<Derived> Base;
46    typedef typename internal::traits<Derived>::Scalar Scalar;
47    typedef Scalar CoeffReturnType;
48
49    enum {
50      RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
51      ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
52      SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
53                                                          internal::traits<Derived>::ColsAtCompileTime>::ret),
54      MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
55      MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
56      MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
57                                                             internal::traits<Derived>::MaxColsAtCompileTime>::ret),
58      IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
59                           || internal::traits<Derived>::MaxColsAtCompileTime == 1
60    };
61
62    /** Default constructor */
63    SolverBase()
64    {}
65
66    ~SolverBase()
67    {}
68
69    using Base::derived;
70
71    /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
72      */
73    template<typename Rhs>
74    inline const Solve<Derived, Rhs>
75    solve(const MatrixBase<Rhs>& b) const
76    {
77      eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
78      return Solve<Derived, Rhs>(derived(), b.derived());
79    }
80
81    /** \internal the return type of transpose() */
82    typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
83    /** \returns an expression of the transposed of the factored matrix.
84      *
85      * A typical usage is to solve for the transposed problem A^T x = b:
86      * \code x = dec.transpose().solve(b); \endcode
87      *
88      * \sa adjoint(), solve()
89      */
90    inline ConstTransposeReturnType transpose() const
91    {
92      return ConstTransposeReturnType(derived());
93    }
94
95    /** \internal the return type of adjoint() */
96    typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
97                        CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
98                        ConstTransposeReturnType
99                     >::type AdjointReturnType;
100    /** \returns an expression of the adjoint of the factored matrix
101      *
102      * A typical usage is to solve for the adjoint problem A' x = b:
103      * \code x = dec.adjoint().solve(b); \endcode
104      *
105      * For real scalar types, this function is equivalent to transpose().
106      *
107      * \sa transpose(), solve()
108      */
109    inline AdjointReturnType adjoint() const
110    {
111      return AdjointReturnType(derived().transpose());
112    }
113
114  protected:
115};
116
117namespace internal {
118
119template<typename Derived>
120struct generic_xpr_base<Derived, MatrixXpr, SolverStorage>
121{
122  typedef SolverBase<Derived> type;
123
124};
125
126} // end namespace internal
127
128} // end namespace Eigen
129
130#endif // EIGEN_SOLVERBASE_H
131