1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSETRIANGULARSOLVER_H
11#define EIGEN_SPARSETRIANGULARSOLVER_H
12
13namespace Eigen {
14
15namespace internal {
16
17template<typename Lhs, typename Rhs, int Mode,
18  int UpLo = (Mode & Lower)
19           ? Lower
20           : (Mode & Upper)
21           ? Upper
22           : -1,
23  int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
24struct sparse_solve_triangular_selector;
25
26// forward substitution, row-major
27template<typename Lhs, typename Rhs, int Mode>
28struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
29{
30  typedef typename Rhs::Scalar Scalar;
31  static void run(const Lhs& lhs, Rhs& other)
32  {
33    for(int col=0 ; col<other.cols() ; ++col)
34    {
35      for(int i=0; i<lhs.rows(); ++i)
36      {
37        Scalar tmp = other.coeff(i,col);
38        Scalar lastVal(0);
39        int lastIndex = 0;
40        for(typename Lhs::InnerIterator it(lhs, i); it; ++it)
41        {
42          lastVal = it.value();
43          lastIndex = it.index();
44          if(lastIndex==i)
45            break;
46          tmp -= lastVal * other.coeff(lastIndex,col);
47        }
48        if (Mode & UnitDiag)
49          other.coeffRef(i,col) = tmp;
50        else
51        {
52          eigen_assert(lastIndex==i);
53          other.coeffRef(i,col) = tmp/lastVal;
54        }
55      }
56    }
57  }
58};
59
60// backward substitution, row-major
61template<typename Lhs, typename Rhs, int Mode>
62struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
63{
64  typedef typename Rhs::Scalar Scalar;
65  static void run(const Lhs& lhs, Rhs& other)
66  {
67    for(int col=0 ; col<other.cols() ; ++col)
68    {
69      for(int i=lhs.rows()-1 ; i>=0 ; --i)
70      {
71        Scalar tmp = other.coeff(i,col);
72        Scalar l_ii = 0;
73        typename Lhs::InnerIterator it(lhs, i);
74        while(it && it.index()<i)
75          ++it;
76        if(!(Mode & UnitDiag))
77        {
78          eigen_assert(it && it.index()==i);
79          l_ii = it.value();
80          ++it;
81        }
82        else if (it && it.index() == i)
83          ++it;
84        for(; it; ++it)
85        {
86          tmp -= it.value() * other.coeff(it.index(),col);
87        }
88
89        if (Mode & UnitDiag)
90          other.coeffRef(i,col) = tmp;
91        else
92          other.coeffRef(i,col) = tmp/l_ii;
93      }
94    }
95  }
96};
97
98// forward substitution, col-major
99template<typename Lhs, typename Rhs, int Mode>
100struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
101{
102  typedef typename Rhs::Scalar Scalar;
103  static void run(const Lhs& lhs, Rhs& other)
104  {
105    for(int col=0 ; col<other.cols() ; ++col)
106    {
107      for(int i=0; i<lhs.cols(); ++i)
108      {
109        Scalar& tmp = other.coeffRef(i,col);
110        if (tmp!=Scalar(0)) // optimization when other is actually sparse
111        {
112          typename Lhs::InnerIterator it(lhs, i);
113          while(it && it.index()<i)
114            ++it;
115          if(!(Mode & UnitDiag))
116          {
117            eigen_assert(it && it.index()==i);
118            tmp /= it.value();
119          }
120          if (it && it.index()==i)
121            ++it;
122          for(; it; ++it)
123            other.coeffRef(it.index(), col) -= tmp * it.value();
124        }
125      }
126    }
127  }
128};
129
130// backward substitution, col-major
131template<typename Lhs, typename Rhs, int Mode>
132struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
133{
134  typedef typename Rhs::Scalar Scalar;
135  static void run(const Lhs& lhs, Rhs& other)
136  {
137    for(int col=0 ; col<other.cols() ; ++col)
138    {
139      for(int i=lhs.cols()-1; i>=0; --i)
140      {
141        Scalar& tmp = other.coeffRef(i,col);
142        if (tmp!=Scalar(0)) // optimization when other is actually sparse
143        {
144          if(!(Mode & UnitDiag))
145          {
146            // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
147            typename Lhs::ReverseInnerIterator it(lhs, i);
148            while(it && it.index()!=i)
149              --it;
150            eigen_assert(it && it.index()==i);
151            other.coeffRef(i,col) /= it.value();
152          }
153          typename Lhs::InnerIterator it(lhs, i);
154          for(; it && it.index()<i; ++it)
155            other.coeffRef(it.index(), col) -= tmp * it.value();
156        }
157      }
158    }
159  }
160};
161
162} // end namespace internal
163
164template<typename ExpressionType,int Mode>
165template<typename OtherDerived>
166void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
167{
168  eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
169  eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
170
171  enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
172
173  typedef typename internal::conditional<copy,
174    typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
175  OtherCopy otherCopy(other.derived());
176
177  internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
178
179  if (copy)
180    other = otherCopy;
181}
182
183template<typename ExpressionType,int Mode>
184template<typename OtherDerived>
185typename internal::plain_matrix_type_column_major<OtherDerived>::type
186SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
187{
188  typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
189  solveInPlace(res);
190  return res;
191}
192
// Sparse right-hand side path: both the triangular factor and the right-hand side are sparse.
194
195namespace internal {
196
197template<typename Lhs, typename Rhs, int Mode,
198  int UpLo = (Mode & Lower)
199           ? Lower
200           : (Mode & Upper)
201           ? Upper
202           : -1,
203  int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
204struct sparse_solve_triangular_sparse_selector;
205
206// forward substitution, col-major
207template<typename Lhs, typename Rhs, int Mode, int UpLo>
208struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
209{
210  typedef typename Rhs::Scalar Scalar;
211  typedef typename promote_index_type<typename traits<Lhs>::Index,
212                                         typename traits<Rhs>::Index>::type Index;
213  static void run(const Lhs& lhs, Rhs& other)
214  {
215    const bool IsLower = (UpLo==Lower);
216    AmbiVector<Scalar,Index> tempVector(other.rows()*2);
217    tempVector.setBounds(0,other.rows());
218
219    Rhs res(other.rows(), other.cols());
220    res.reserve(other.nonZeros());
221
222    for(int col=0 ; col<other.cols() ; ++col)
223    {
224      // FIXME estimate number of non zeros
225      tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
226      tempVector.setZero();
227      tempVector.restart();
228      for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
229      {
230        tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
231      }
232
233      for(int i=IsLower?0:lhs.cols()-1;
234          IsLower?i<lhs.cols():i>=0;
235          i+=IsLower?1:-1)
236      {
237        tempVector.restart();
238        Scalar& ci = tempVector.coeffRef(i);
239        if (ci!=Scalar(0))
240        {
241          // find
242          typename Lhs::InnerIterator it(lhs, i);
243          if(!(Mode & UnitDiag))
244          {
245            if (IsLower)
246            {
247              eigen_assert(it.index()==i);
248              ci /= it.value();
249            }
250            else
251              ci /= lhs.coeff(i,i);
252          }
253          tempVector.restart();
254          if (IsLower)
255          {
256            if (it.index()==i)
257              ++it;
258            for(; it; ++it)
259              tempVector.coeffRef(it.index()) -= ci * it.value();
260          }
261          else
262          {
263            for(; it && it.index()<i; ++it)
264              tempVector.coeffRef(it.index()) -= ci * it.value();
265          }
266        }
267      }
268
269
270      int count = 0;
271      // FIXME compute a reference value to filter zeros
272      for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector/*,1e-12*/); it; ++it)
273      {
274        ++ count;
275//         std::cerr << "fill " << it.index() << ", " << col << "\n";
276//         std::cout << it.value() << "  ";
277        // FIXME use insertBack
278        res.insert(it.index(), col) = it.value();
279      }
280//       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
281    }
282    res.finalize();
283    other = res.markAsRValue();
284  }
285};
286
287} // end namespace internal
288
289template<typename ExpressionType,int Mode>
290template<typename OtherDerived>
291void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
292{
293  eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
294  eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
295
296//   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
297
298//   typedef typename internal::conditional<copy,
299//     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
300//   OtherCopy otherCopy(other.derived());
301
302  internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
303
304//   if (copy)
305//     other = otherCopy;
306}
307
308#ifdef EIGEN2_SUPPORT
309
310// deprecated stuff:
311
312/** \deprecated */
313template<typename Derived>
314template<typename OtherDerived>
315void SparseMatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const
316{
317  this->template triangular<Flags&(Upper|Lower)>().solveInPlace(other);
318}
319
320/** \deprecated */
321template<typename Derived>
322template<typename OtherDerived>
323typename internal::plain_matrix_type_column_major<OtherDerived>::type
324SparseMatrixBase<Derived>::solveTriangular(const MatrixBase<OtherDerived>& other) const
325{
326  typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
327  derived().solveTriangularInPlace(res);
328  return res;
329}
330#endif // EIGEN2_SUPPORT
331
332} // end namespace Eigen
333
334#endif // EIGEN_SPARSETRIANGULARSOLVER_H
335