1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5// Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_SPARSE_TRIANGULARVIEW_H
12#define EIGEN_SPARSE_TRIANGULARVIEW_H
13
14namespace Eigen {
15
16/** \ingroup SparseCore_Module
17  *
18  * \brief Base class for a triangular part in a \b sparse matrix
19  *
20  * This class is an abstract base class of class TriangularView, and objects of type TriangularViewImpl cannot be instantiated.
21  * It extends class TriangularView with additional methods which are available for sparse expressions only.
22  *
23  * \sa class TriangularView, SparseMatrixBase::triangularView()
24  */
25template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
26  : public SparseMatrixBase<TriangularView<MatrixType,Mode> >
27{
28    enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
29                    || ((Mode&Upper) &&  (MatrixType::Flags&RowMajorBit)),
30           SkipLast = !SkipFirst,
31           SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
32           HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
33    };
34
35    typedef TriangularView<MatrixType,Mode> TriangularViewType;
36
37  protected:
38    // dummy solve function to make TriangularView happy.
39    void solve() const;
40
41    typedef SparseMatrixBase<TriangularViewType> Base;
42  public:
43
44    EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
45
46    typedef typename MatrixType::Nested MatrixTypeNested;
47    typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
48    typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
49
50    template<typename RhsType, typename DstType>
51    EIGEN_DEVICE_FUNC
52    EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
53      if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
54        dst = rhs;
55      this->solveInPlace(dst);
56    }
57
58    /** Applies the inverse of \c *this to the dense vector or matrix \a other, "in-place" */
59    template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
60
61    /** Applies the inverse of \c *this to the sparse vector or matrix \a other, "in-place" */
62    template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
63
64};
65
66namespace internal {
67
68template<typename ArgType, unsigned int Mode>
69struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased>
70 : evaluator_base<TriangularView<ArgType,Mode> >
71{
72  typedef TriangularView<ArgType,Mode> XprType;
73
74protected:
75
76  typedef typename XprType::Scalar Scalar;
77  typedef typename XprType::StorageIndex StorageIndex;
78  typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
79
80  enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit))
81                    || ((Mode&Upper) &&  (ArgType::Flags&RowMajorBit)),
82         SkipLast = !SkipFirst,
83         SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
84         HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
85  };
86
87public:
88
89  enum {
90    CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
91    Flags = XprType::Flags
92  };
93
94  explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
95
96  inline Index nonZerosEstimate() const {
97    return m_argImpl.nonZerosEstimate();
98  }
99
100  class InnerIterator : public EvalIterator
101  {
102      typedef EvalIterator Base;
103    public:
104
105      EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
106        : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
107      {
108        if(SkipFirst)
109        {
110          while((*this) && ((HasUnitDiag||SkipDiag)  ? this->index()<=outer : this->index()<outer))
111            Base::operator++();
112          if(HasUnitDiag)
113            m_returnOne = m_containsDiag;
114        }
115        else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
116        {
117          if((!SkipFirst) && Base::operator bool())
118            Base::operator++();
119          m_returnOne = m_containsDiag;
120        }
121      }
122
123      EIGEN_STRONG_INLINE InnerIterator& operator++()
124      {
125        if(HasUnitDiag && m_returnOne)
126          m_returnOne = false;
127        else
128        {
129          Base::operator++();
130          if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
131          {
132            if((!SkipFirst) && Base::operator bool())
133              Base::operator++();
134            m_returnOne = m_containsDiag;
135          }
136        }
137        return *this;
138      }
139
140      EIGEN_STRONG_INLINE operator bool() const
141      {
142        if(HasUnitDiag && m_returnOne)
143          return true;
144        if(SkipFirst) return  Base::operator bool();
145        else
146        {
147          if (SkipDiag) return (Base::operator bool() && this->index() < this->outer());
148          else return (Base::operator bool() && this->index() <= this->outer());
149        }
150      }
151
152//       inline Index row() const { return (ArgType::Flags&RowMajorBit ? Base::outer() : this->index()); }
153//       inline Index col() const { return (ArgType::Flags&RowMajorBit ? this->index() : Base::outer()); }
154      inline StorageIndex index() const
155      {
156        if(HasUnitDiag && m_returnOne)  return internal::convert_index<StorageIndex>(Base::outer());
157        else                            return Base::index();
158      }
159      inline Scalar value() const
160      {
161        if(HasUnitDiag && m_returnOne)  return Scalar(1);
162        else                            return Base::value();
163      }
164
165    protected:
166      bool m_returnOne;
167      bool m_containsDiag;
168    private:
169      Scalar& valueRef();
170  };
171
172protected:
173  evaluator<ArgType> m_argImpl;
174  const ArgType& m_arg;
175};
176
177} // end namespace internal
178
179template<typename Derived>
180template<int Mode>
181inline const TriangularView<const Derived, Mode>
182SparseMatrixBase<Derived>::triangularView() const
183{
184  return TriangularView<const Derived, Mode>(derived());
185}
186
187} // end namespace Eigen
188
189#endif // EIGEN_SPARSE_TRIANGULARVIEW_H
190