// TriangularMatrixVector.h revision c981c48f5bc9aefeffc0bcb0cc3934c2fae179dd
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13namespace Eigen {
14
15namespace internal {
16
17template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18struct triangular_matrix_vector_product;
19
20template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22{
23  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
24  enum {
25    IsLower = ((Mode&Lower)==Lower),
26    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28  };
29  static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
31  {
32    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
33    Index size = (std::min)(_rows,_cols);
34    Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
35    Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
36
37    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
38    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
39    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
40
41    typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
42    const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
43    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
44
45    typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
46    ResMap res(_res,rows);
47
48    for (Index pi=0; pi<size; pi+=PanelWidth)
49    {
50      Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
51      for (Index k=0; k<actualPanelWidth; ++k)
52      {
53        Index i = pi + k;
54        Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
55        Index r = IsLower ? actualPanelWidth-k : k+1;
56        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
57          res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
58        if (HasUnitDiag)
59          res.coeffRef(i) += alpha * cjRhs.coeff(i);
60      }
61      Index r = IsLower ? rows - pi - actualPanelWidth : pi;
62      if (r>0)
63      {
64        Index s = IsLower ? pi+actualPanelWidth : 0;
65        general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
66            r, actualPanelWidth,
67            &lhs.coeffRef(s,pi), lhsStride,
68            &rhs.coeffRef(pi), rhsIncr,
69            &res.coeffRef(s), resIncr, alpha);
70      }
71    }
72    if((!IsLower) && cols>size)
73    {
74      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
75          rows, cols-size,
76          &lhs.coeffRef(0,size), lhsStride,
77          &rhs.coeffRef(size), rhsIncr,
78          _res, resIncr, alpha);
79    }
80  }
81};
82
83template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
84struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
85{
86  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
87  enum {
88    IsLower = ((Mode&Lower)==Lower),
89    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
90    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
91  };
92  static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
93                  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
94  {
95    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
96    Index diagSize = (std::min)(_rows,_cols);
97    Index rows = IsLower ? _rows : diagSize;
98    Index cols = IsLower ? diagSize : _cols;
99
100    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
101    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
102    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
103
104    typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
105    const RhsMap rhs(_rhs,cols);
106    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
107
108    typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
109    ResMap res(_res,rows,InnerStride<>(resIncr));
110
111    for (Index pi=0; pi<diagSize; pi+=PanelWidth)
112    {
113      Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
114      for (Index k=0; k<actualPanelWidth; ++k)
115      {
116        Index i = pi + k;
117        Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
118        Index r = IsLower ? k+1 : actualPanelWidth-k;
119        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
120          res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
121        if (HasUnitDiag)
122          res.coeffRef(i) += alpha * cjRhs.coeff(i);
123      }
124      Index r = IsLower ? pi : cols - pi - actualPanelWidth;
125      if (r>0)
126      {
127        Index s = IsLower ? 0 : pi + actualPanelWidth;
128        general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
129            actualPanelWidth, r,
130            &lhs.coeffRef(pi,s), lhsStride,
131            &rhs.coeffRef(s), rhsIncr,
132            &res.coeffRef(pi), resIncr, alpha);
133      }
134    }
135    if(IsLower && rows>diagSize)
136    {
137      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
138            rows-diagSize, cols,
139            &lhs.coeffRef(diagSize,0), lhsStride,
140            &rhs.coeffRef(0), rhsIncr,
141            &res.coeffRef(diagSize), resIncr, alpha);
142    }
143  }
144};
145
/***************************************************************************
* Product expressions dispatching triangular * vector products to the
* triangular_matrix_vector_product kernels
***************************************************************************/
149
150template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
151struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
152 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
153{};
154
155template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
156struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
157 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
158{};
159
160
// Dispatches a triangular matrix * vector product to the kernel matching the storage
// order of the triangular factor (specialized for ColMajor and RowMajor).
template<int Order> struct trmv_selector;
163
164} // end namespace internal
165
166template<int Mode, typename Lhs, typename Rhs>
167struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
168  : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
169{
170  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
171
172  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
173
174  template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
175  {
176    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
177
178    internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
179  }
180};
181
182template<int Mode, typename Lhs, typename Rhs>
183struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
184  : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
185{
186  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
187
188  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
189
190  template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
191  {
192    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
193
194    typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
195    Transpose<Dest> dstT(dst);
196    internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
197      TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
198  }
199};
200
201namespace internal {
202
// TODO: find a way to factor this code together with gemv_selector, since the logic is the same.
204
205template<> struct trmv_selector<ColMajor>
206{
207  template<int Mode, typename Lhs, typename Rhs, typename Dest>
208  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
209  {
210    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
211    typedef typename ProductType::Index Index;
212    typedef typename ProductType::LhsScalar   LhsScalar;
213    typedef typename ProductType::RhsScalar   RhsScalar;
214    typedef typename ProductType::Scalar      ResScalar;
215    typedef typename ProductType::RealScalar  RealScalar;
216    typedef typename ProductType::ActualLhsType ActualLhsType;
217    typedef typename ProductType::ActualRhsType ActualRhsType;
218    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
219    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
220    typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
221
222    typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
223    typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
224
225    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
226                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());
227
228    enum {
229      // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
230      // on, the other hand it is good for the cache to pack the vector anyways...
231      EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
232      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
233      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
234    };
235
236    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
237
238    bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
239    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
240
241    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
242
243    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
244                                                  evalToDest ? dest.data() : static_dest.data());
245
246    if(!evalToDest)
247    {
248      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249      int size = dest.size();
250      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
251      #endif
252      if(!alphaIsCompatible)
253      {
254        MappedDest(actualDestPtr, dest.size()).setZero();
255        compatibleAlpha = RhsScalar(1);
256      }
257      else
258        MappedDest(actualDestPtr, dest.size()) = dest;
259    }
260
261    internal::triangular_matrix_vector_product
262      <Index,Mode,
263       LhsScalar, LhsBlasTraits::NeedToConjugate,
264       RhsScalar, RhsBlasTraits::NeedToConjugate,
265       ColMajor>
266      ::run(actualLhs.rows(),actualLhs.cols(),
267            actualLhs.data(),actualLhs.outerStride(),
268            actualRhs.data(),actualRhs.innerStride(),
269            actualDestPtr,1,compatibleAlpha);
270
271    if (!evalToDest)
272    {
273      if(!alphaIsCompatible)
274        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
275      else
276        dest = MappedDest(actualDestPtr, dest.size());
277    }
278  }
279};
280
281template<> struct trmv_selector<RowMajor>
282{
283  template<int Mode, typename Lhs, typename Rhs, typename Dest>
284  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
285  {
286    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
287    typedef typename ProductType::LhsScalar LhsScalar;
288    typedef typename ProductType::RhsScalar RhsScalar;
289    typedef typename ProductType::Scalar    ResScalar;
290    typedef typename ProductType::Index Index;
291    typedef typename ProductType::ActualLhsType ActualLhsType;
292    typedef typename ProductType::ActualRhsType ActualRhsType;
293    typedef typename ProductType::_ActualRhsType _ActualRhsType;
294    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
295    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
296
297    typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
298    typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
299
300    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
301                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());
302
303    enum {
304      DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
305    };
306
307    gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
308
309    ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
310        DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
311
312    if(!DirectlyUseRhs)
313    {
314      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
315      int size = actualRhs.size();
316      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
317      #endif
318      Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
319    }
320
321    internal::triangular_matrix_vector_product
322      <Index,Mode,
323       LhsScalar, LhsBlasTraits::NeedToConjugate,
324       RhsScalar, RhsBlasTraits::NeedToConjugate,
325       RowMajor>
326      ::run(actualLhs.rows(),actualLhs.cols(),
327            actualLhs.data(),actualLhs.outerStride(),
328            actualRhsPtr,1,
329            dest.data(),dest.innerStride(),
330            actualAlpha);
331  }
332};
333
334} // end namespace internal
335
336} // end namespace Eigen
337
338#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
339