1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13namespace Eigen {
14
15namespace internal {
16
/** \internal
  * Primary template of the low-level triangular-matrix * vector kernel.
  * It is only declared here: the ColMajor and RowMajor values of \a StorageOrder are
  * implemented by partial specializations. \a Version defaults to Specialized.
  */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
struct triangular_matrix_vector_product;
19
20template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22{
23  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
24  enum {
25    IsLower = ((Mode&Lower)==Lower),
26    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28  };
29  static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
31};
32
33template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
34EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
35  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
36        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
37  {
38    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
39    Index size = (std::min)(_rows,_cols);
40    Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
41    Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
42
43    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
44    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
45    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
46
47    typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
48    const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
49    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
50
51    typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
52    ResMap res(_res,rows);
53
54    for (Index pi=0; pi<size; pi+=PanelWidth)
55    {
56      Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
57      for (Index k=0; k<actualPanelWidth; ++k)
58      {
59        Index i = pi + k;
60        Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
61        Index r = IsLower ? actualPanelWidth-k : k+1;
62        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
63          res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
64        if (HasUnitDiag)
65          res.coeffRef(i) += alpha * cjRhs.coeff(i);
66      }
67      Index r = IsLower ? rows - pi - actualPanelWidth : pi;
68      if (r>0)
69      {
70        Index s = IsLower ? pi+actualPanelWidth : 0;
71        general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
72            r, actualPanelWidth,
73            &lhs.coeffRef(s,pi), lhsStride,
74            &rhs.coeffRef(pi), rhsIncr,
75            &res.coeffRef(s), resIncr, alpha);
76      }
77    }
78    if((!IsLower) && cols>size)
79    {
80      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
81          rows, cols-size,
82          &lhs.coeffRef(0,size), lhsStride,
83          &rhs.coeffRef(size), rhsIncr,
84          _res, resIncr, alpha);
85    }
86  }
87
88template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
89struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
90{
91  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
92  enum {
93    IsLower = ((Mode&Lower)==Lower),
94    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
95    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
96  };
97  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
98                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
99};
100
101template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
102EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
103  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
104        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
105  {
106    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
107    Index diagSize = (std::min)(_rows,_cols);
108    Index rows = IsLower ? _rows : diagSize;
109    Index cols = IsLower ? diagSize : _cols;
110
111    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
112    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
113    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
114
115    typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
116    const RhsMap rhs(_rhs,cols);
117    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
118
119    typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
120    ResMap res(_res,rows,InnerStride<>(resIncr));
121
122    for (Index pi=0; pi<diagSize; pi+=PanelWidth)
123    {
124      Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
125      for (Index k=0; k<actualPanelWidth; ++k)
126      {
127        Index i = pi + k;
128        Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
129        Index r = IsLower ? k+1 : actualPanelWidth-k;
130        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
131          res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
132        if (HasUnitDiag)
133          res.coeffRef(i) += alpha * cjRhs.coeff(i);
134      }
135      Index r = IsLower ? pi : cols - pi - actualPanelWidth;
136      if (r>0)
137      {
138        Index s = IsLower ? 0 : pi + actualPanelWidth;
139        general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
140            actualPanelWidth, r,
141            &lhs.coeffRef(pi,s), lhsStride,
142            &rhs.coeffRef(s), rhsIncr,
143            &res.coeffRef(pi), resIncr, alpha);
144      }
145    }
146    if(IsLower && rows>diagSize)
147    {
148      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
149            rows-diagSize, cols,
150            &lhs.coeffRef(diagSize,0), lhsStride,
151            &rhs.coeffRef(0), rhsIncr,
152            &res.coeffRef(diagSize), resIncr, alpha);
153    }
154  }
155
156/***************************************************************************
157* Wrapper to product_triangular_vector
158***************************************************************************/
159
/** \internal
  * Traits of TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>: inherited unchanged
  * from the traits of the corresponding ProductBase.
  * NOTE(review): the false/true arguments presumably flag "lhs is not a vector" and
  * "rhs is a vector" -- confirm against the TriangularProduct declaration.
  */
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};
164
/** \internal
  * Traits of TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>: inherited unchanged
  * from the traits of the corresponding ProductBase.
  * NOTE(review): the true/false arguments presumably flag "lhs is a vector" and
  * "rhs is not a vector" -- confirm against the TriangularProduct declaration.
  */
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
169
170
/** \internal
  * Dispatcher invoked by TriangularProduct::scaleAndAddTo. It is only declared here and is
  * specialized for the ColMajor and RowMajor values of \a StorageOrder, i.e. the storage
  * order of the (possibly transposed) triangular operand.
  */
template<int StorageOrder>
struct trmv_selector;
173
174} // end namespace internal
175
176template<int Mode, typename Lhs, typename Rhs>
177struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
178  : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
179{
180  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
181
182  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
183
184  template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
185  {
186    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
187
188    internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
189  }
190};
191
192template<int Mode, typename Lhs, typename Rhs>
193struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
194  : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
195{
196  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
197
198  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
199
200  template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
201  {
202    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
203
204    typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
205    Transpose<Dest> dstT(dst);
206    internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
207      TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
208  }
209};
210
211namespace internal {
212
213// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
214
215template<> struct trmv_selector<ColMajor>
216{
217  template<int Mode, typename Lhs, typename Rhs, typename Dest>
218  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
219  {
220    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
221    typedef typename ProductType::Index Index;
222    typedef typename ProductType::LhsScalar   LhsScalar;
223    typedef typename ProductType::RhsScalar   RhsScalar;
224    typedef typename ProductType::Scalar      ResScalar;
225    typedef typename ProductType::RealScalar  RealScalar;
226    typedef typename ProductType::ActualLhsType ActualLhsType;
227    typedef typename ProductType::ActualRhsType ActualRhsType;
228    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
229    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
230    typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
231
232    typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
233    typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
234
235    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
236                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());
237
238    enum {
239      // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
240      // on, the other hand it is good for the cache to pack the vector anyways...
241      EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
242      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
243      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
244    };
245
246    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
247
248    bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
249    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
250
251    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
252
253    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
254                                                  evalToDest ? dest.data() : static_dest.data());
255
256    if(!evalToDest)
257    {
258      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
259      Index size = dest.size();
260      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
261      #endif
262      if(!alphaIsCompatible)
263      {
264        MappedDest(actualDestPtr, dest.size()).setZero();
265        compatibleAlpha = RhsScalar(1);
266      }
267      else
268        MappedDest(actualDestPtr, dest.size()) = dest;
269    }
270
271    internal::triangular_matrix_vector_product
272      <Index,Mode,
273       LhsScalar, LhsBlasTraits::NeedToConjugate,
274       RhsScalar, RhsBlasTraits::NeedToConjugate,
275       ColMajor>
276      ::run(actualLhs.rows(),actualLhs.cols(),
277            actualLhs.data(),actualLhs.outerStride(),
278            actualRhs.data(),actualRhs.innerStride(),
279            actualDestPtr,1,compatibleAlpha);
280
281    if (!evalToDest)
282    {
283      if(!alphaIsCompatible)
284        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
285      else
286        dest = MappedDest(actualDestPtr, dest.size());
287    }
288  }
289};
290
291template<> struct trmv_selector<RowMajor>
292{
293  template<int Mode, typename Lhs, typename Rhs, typename Dest>
294  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
295  {
296    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
297    typedef typename ProductType::LhsScalar LhsScalar;
298    typedef typename ProductType::RhsScalar RhsScalar;
299    typedef typename ProductType::Scalar    ResScalar;
300    typedef typename ProductType::Index Index;
301    typedef typename ProductType::ActualLhsType ActualLhsType;
302    typedef typename ProductType::ActualRhsType ActualRhsType;
303    typedef typename ProductType::_ActualRhsType _ActualRhsType;
304    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
305    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
306
307    typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
308    typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
309
310    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
311                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());
312
313    enum {
314      DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
315    };
316
317    gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
318
319    ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
320        DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
321
322    if(!DirectlyUseRhs)
323    {
324      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
325      int size = actualRhs.size();
326      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
327      #endif
328      Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
329    }
330
331    internal::triangular_matrix_vector_product
332      <Index,Mode,
333       LhsScalar, LhsBlasTraits::NeedToConjugate,
334       RhsScalar, RhsBlasTraits::NeedToConjugate,
335       RowMajor>
336      ::run(actualLhs.rows(),actualLhs.cols(),
337            actualLhs.data(),actualLhs.outerStride(),
338            actualRhsPtr,1,
339            dest.data(),dest.innerStride(),
340            actualAlpha);
341  }
342};
343
344} // end namespace internal
345
346} // end namespace Eigen
347
348#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
349