1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SOLVEWITHGUESS_H
11#define EIGEN_SOLVEWITHGUESS_H
12
13namespace Eigen {
14
// Forward declaration: the internal::traits specialization that follows names
// SolveWithGuess before the class template itself is defined.
template<typename Decomposition, typename RhsType, typename GuessType> class SolveWithGuess;
16
/** \class SolveWithGuess
  * \ingroup IterativeLinearSolvers_Module
  *
  * \brief Pseudo expression representing a solving operation with an initial guess
  *
  * \tparam Decomposition the type of the matrix or decomposition object
  * \tparam RhsType the type of the right-hand side
  * \tparam GuessType the type of the initial guess
  *
  * This class represents an expression of A.solveWithGuess(B,X0)
  * and most of the time this is the only way it is used.
  *
  */
29namespace internal {
30
31
32template<typename Decomposition, typename RhsType, typename GuessType>
33struct traits<SolveWithGuess<Decomposition, RhsType, GuessType> >
34  : traits<Solve<Decomposition,RhsType> >
35{};
36
37}
38
39
40template<typename Decomposition, typename RhsType, typename GuessType>
41class SolveWithGuess : public internal::generic_xpr_base<SolveWithGuess<Decomposition,RhsType,GuessType>, MatrixXpr, typename internal::traits<RhsType>::StorageKind>::type
42{
43public:
44  typedef typename internal::traits<SolveWithGuess>::Scalar Scalar;
45  typedef typename internal::traits<SolveWithGuess>::PlainObject PlainObject;
46  typedef typename internal::generic_xpr_base<SolveWithGuess<Decomposition,RhsType,GuessType>, MatrixXpr, typename internal::traits<RhsType>::StorageKind>::type Base;
47  typedef typename internal::ref_selector<SolveWithGuess>::type Nested;
48
49  SolveWithGuess(const Decomposition &dec, const RhsType &rhs, const GuessType &guess)
50    : m_dec(dec), m_rhs(rhs), m_guess(guess)
51  {}
52
53  EIGEN_DEVICE_FUNC Index rows() const { return m_dec.cols(); }
54  EIGEN_DEVICE_FUNC Index cols() const { return m_rhs.cols(); }
55
56  EIGEN_DEVICE_FUNC const Decomposition& dec()   const { return m_dec; }
57  EIGEN_DEVICE_FUNC const RhsType&       rhs()   const { return m_rhs; }
58  EIGEN_DEVICE_FUNC const GuessType&     guess() const { return m_guess; }
59
60protected:
61  const Decomposition &m_dec;
62  const RhsType       &m_rhs;
63  const GuessType     &m_guess;
64
65private:
66  Scalar coeff(Index row, Index col) const;
67  Scalar coeff(Index i) const;
68};
69
70namespace internal {
71
72// Evaluator of SolveWithGuess -> eval into a temporary
73template<typename Decomposition, typename RhsType, typename GuessType>
74struct evaluator<SolveWithGuess<Decomposition,RhsType, GuessType> >
75  : public evaluator<typename SolveWithGuess<Decomposition,RhsType,GuessType>::PlainObject>
76{
77  typedef SolveWithGuess<Decomposition,RhsType,GuessType> SolveType;
78  typedef typename SolveType::PlainObject PlainObject;
79  typedef evaluator<PlainObject> Base;
80
81  evaluator(const SolveType& solve)
82    : m_result(solve.rows(), solve.cols())
83  {
84    ::new (static_cast<Base*>(this)) Base(m_result);
85    m_result = solve.guess();
86    solve.dec()._solve_with_guess_impl(solve.rhs(), m_result);
87  }
88
89protected:
90  PlainObject m_result;
91};
92
93// Specialization for "dst = dec.solveWithGuess(rhs)"
94// NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse specialization must exist somewhere
95template<typename DstXprType, typename DecType, typename RhsType, typename GuessType, typename Scalar>
96struct Assignment<DstXprType, SolveWithGuess<DecType,RhsType,GuessType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
97{
98  typedef SolveWithGuess<DecType,RhsType,GuessType> SrcXprType;
99  static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
100  {
101    Index dstRows = src.rows();
102    Index dstCols = src.cols();
103    if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
104      dst.resize(dstRows, dstCols);
105
106    dst = src.guess();
107    src.dec()._solve_with_guess_impl(src.rhs(), dst/*, src.guess()*/);
108  }
109};
110
} // end namespace internal
112
113} // end namespace Eigen
114
115#endif // EIGEN_SOLVEWITHGUESS_H
116