1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
11#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
12
13namespace Eigen {
14
// Forward declaration of the rank-1 triangular update kernel used below by the
// outer-product path of general_product_to_triangular_selector (IsOuterProduct == true).
// Its definition is not in this file; presumably it performs
// mat += alpha * lhs * rhs^T restricted to the UpLo half (TODO confirm against its definition).
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
17
18namespace internal {
19
20/**********************************************************************
21* This file implements a general A * B product while
22* evaluating only one triangular part of the product.
* This is a generalization of the self-adjoint product (C += A A^T)
* computed by the level 3 BLAS routine SYRK.
25**********************************************************************/
26
27// forward declarations (defined at the end of this file)
28template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
29struct tribb_kernel;
30
/* Optimized matrix-matrix product evaluating only one triangular half (UpLo) of a
 * size x size result. It is specialized below on ResStorageOrder: the RowMajor case
 * is rewritten as a ColMajor product of the transposed operands, and the ColMajor
 * case holds the blocked implementation. The specializations in this file do not
 * use the Version parameter. */
template <typename Index,
          typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
          typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
                              int ResStorageOrder, int  UpLo, int Version = Specialized>
struct general_matrix_matrix_triangular_product;
37
38// as usual if the result is row major => we transpose the product
39template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
40                          typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int  UpLo, int Version>
41struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo,Version>
42{
43  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
44  static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
45                                      const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride,
46                                      const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking)
47  {
48    general_matrix_matrix_triangular_product<Index,
49        RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
50        LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
51        ColMajor, UpLo==Lower?Upper:Lower>
52      ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking);
53  }
54};
55
// Column-major result: the actual blocked implementation.
//
// Accumulates alpha * lhs * rhs into the UpLo triangle of the size x size column-major
// matrix _res, where lhs is size x depth and rhs is depth x size (each read through a
// blas data mapper honoring its own storage order). Only the requested triangle is touched:
//   - the depth dimension is split into slices of kc (k2 loop); for each slice the whole
//     rhs slice (actual_kc x size) is packed once into blockB;
//   - the rows are split into panels of mc (i2 loop); each lhs panel
//     (actual_mc x actual_kc) is packed into blockA;
//   - each actual_mc x size row panel of _res is then processed in three parts: the
//     off-diagonal part inside the requested triangle (gebp), the actual_mc x actual_mc
//     diagonal block (tribb_kernel), and the part outside the triangle (skipped).
// NOTE: the half is selected with UpLo==Lower / UpLo==Upper tests, so UpLo is expected
// to be exactly Lower or Upper.
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
                          typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int  UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Version>
{
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
                                      const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride,
                                      const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
  {
    typedef gebp_traits<LhsScalar,RhsScalar> Traits;

    typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
    typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
    LhsMapper lhs(_lhs,lhsStride);
    RhsMapper rhs(_rhs,rhsStride);
    ResMapper res(_res, resStride);

    Index kc = blocking.kc();
    Index mc = (std::min)(size,blocking.mc());

    // !!! mc must be a multiple of nr:
    // i2 advances in steps of mc and is used below to index the packed rhs at
    // blockB + actual_kc*i2, which only lands on the start of an nr-wide packed
    // column panel when i2 is a multiple of nr.
    if(mc > Traits::nr)
      mc = (mc/Traits::nr)*Traits::nr;

    // blockA holds one packed lhs panel, blockB the whole packed rhs slice.
    std::size_t sizeA = kc*mc;
    std::size_t sizeB = kc*size;

    // Reuse the buffers owned by `blocking` when it provides them; otherwise the macro
    // provides a temporary aligned buffer (see its definition for stack vs heap).
    ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
    ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());

    gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
    gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
    gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
    tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb;

    for(Index k2=0; k2<depth; k2+=kc)
    {
      const Index actual_kc = (std::min)(k2+kc,depth)-k2;

      // pack rows [k2, k2+actual_kc) of rhs, all size columns, once per depth slice
      pack_rhs(blockB, rhs.getSubMapper(k2,0), actual_kc, size);

      for(Index i2=0; i2<size; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,size)-i2;

        // pack lhs rows [i2, i2+actual_mc), depth columns [k2, k2+actual_kc)
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

        // the selected actual_mc * size panel of res is split into three different part:
        //  1 - before the diagonal => processed with gebp or skipped
        //  2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
        //  3 - after the diagonal => processed with gebp or skipped
        // Lower: columns [0, i2) of this row panel lie strictly below the diagonal.
        if (UpLo==Lower)
          gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
               (std::min)(size,i2), alpha, -1, -1, 0, 0);


        // Diagonal block: rows/cols [i2, i2+actual_mc) of res; its rhs columns start at
        // column i2 of the packed slice, i.e. at offset actual_kc*i2 in blockB.
        sybb(_res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);

        // Upper: columns [i2+actual_mc, size) of this row panel lie strictly above the diagonal.
        if (UpLo==Upper)
        {
          Index j2 = i2+actual_mc;
          gebp(res.getSubMapper(i2, j2), blockA, blockB+actual_kc*j2, actual_mc,
               actual_kc, (std::max)(Index(0), size-j2), alpha, -1, -1, 0, 0);
        }
      }
    }
  }
};
126
// Optimized packed-block * packed-block product kernel that evaluates only one given
// triangular part (UpLo) of a square size x size destination block.
// This kernel is built on top of the gebp kernel:
// - the destination block is processed per column panel of size x BlockSize, where
//   BlockSize = lcm(mr, nr) is the smallest width that is a whole number of both
//   mr-row lhs panels and nr-column rhs panels, so the offsets blockA+depth*j and
//   blockB+j*depth used below start on packed-panel boundaries;
// - then, as usual, each panel is split into three parts along the diagonal:
//   the sub block inside the requested triangle is computed directly with gebp,
//   the one outside it is skipped, and the BlockSize x BlockSize block overlapping
//   the diagonal is evaluated into a small zeroed temporary buffer whose requested
//   triangle (diagonal included) is then added into the result.
//
// operator() parameters:
//   _res, resStride : column-major destination block and its column stride
//   blockA          : lhs packed by gemm_pack_lhs (size rows, depth columns)
//   blockB          : rhs packed by gemm_pack_rhs (depth rows, size columns)
//   size, depth     : dimensions described above
//   alpha           : scale factor applied to the product before accumulation
// NOTE: UpLo is expected to be exactly Lower or Upper; any other value skips both
// off-diagonal parts and accumulates the upper half of each diagonal block.
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
struct tribb_kernel
{
  typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
  typedef typename Traits::ResScalar ResScalar;

  enum {
    BlockSize  = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
  };
  void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
  {
    typedef blas_data_mapper<ResScalar, Index, ColMajor> ResMapper;
    ResMapper res(_res, resStride);
    gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel;

    // Scratch space for the diagonal micro block; gebp accumulates into it, hence
    // the setZero() before each use below.
    Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert()));

    // let's process the block per panel of actual_mc x BlockSize,
    // again, each is split into three parts, etc.
    for (Index j=0; j<size; j+=BlockSize)
    {
      Index actualBlockSize = std::min<Index>(BlockSize,size - j);
      // packed rhs columns [j, j+actualBlockSize)
      const RhsScalar* actual_b = blockB+j*depth;

      // Upper: rows [0, j) of this column panel lie strictly above the diagonal.
      if(UpLo==Upper)
        gebp_kernel(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
                    -1, -1, 0, 0);

      // selfadjoint micro block
      {
        Index i = j;
        buffer.setZero();
        // 1 - apply the kernel on the temporary buffer
        gebp_kernel(ResMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
                    -1, -1, 0, 0);
        // 2 - triangular accumulation (diagonal included: for Lower i1 starts at j1,
        //     for Upper it stops at j1)
        for(Index j1=0; j1<actualBlockSize; ++j1)
        {
          ResScalar* r = &res(i, j + j1);
          for(Index i1=UpLo==Lower ? j1 : 0;
              UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
            r[i1] += buffer(i1,j1);
        }
      }

      // Lower: rows [j+actualBlockSize, size) of this column panel lie strictly below the diagonal.
      if(UpLo==Lower)
      {
        Index i = j+actualBlockSize;
        gebp_kernel(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
                    depth, actualBlockSize, alpha, -1, -1, 0, 0);
      }
    }
  }
};
190
191} // end namespace internal
192
193// high level API
194
// Dispatches a triangular product assignment to one of two paths. When IsOuterProduct
// is true (inner dimension of one), it uses a rank-1 update. Otherwise it uses the
// blocked general_matrix_matrix_triangular_product. Both specializations follow below.
template<typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
struct general_product_to_triangular_selector;
197
198
199template<typename MatrixType, typename ProductType, int UpLo>
200struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
201{
202  static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
203  {
204    typedef typename MatrixType::Scalar Scalar;
205
206    typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
207    typedef internal::blas_traits<Lhs> LhsBlasTraits;
208    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
209    typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
210    typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
211
212    typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
213    typedef internal::blas_traits<Rhs> RhsBlasTraits;
214    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
215    typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
216    typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
217
218    Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
219
220    if(!beta)
221      mat.template triangularView<UpLo>().setZero();
222
223    enum {
224      StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
225      UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1,
226      UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1
227    };
228
229    internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs;
230    ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(),
231      (UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data()));
232    if(!UseLhsDirectly) Map<typename _ActualLhs::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
233
234    internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs;
235    ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(),
236      (UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data()));
237    if(!UseRhsDirectly) Map<typename _ActualRhs::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
238
239
240    selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
241                              LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
242                              RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>
243          ::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha);
244  }
245};
246
247template<typename MatrixType, typename ProductType, int UpLo>
248struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
249{
250  static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
251  {
252    typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
253    typedef internal::blas_traits<Lhs> LhsBlasTraits;
254    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
255    typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
256    typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
257
258    typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
259    typedef internal::blas_traits<Rhs> RhsBlasTraits;
260    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
261    typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
262    typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
263
264    typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
265
266    if(!beta)
267      mat.template triangularView<UpLo>().setZero();
268
269    enum {
270      IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
271      LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0,
272      RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0
273    };
274
275    Index size = mat.cols();
276    Index depth = actualLhs.cols();
277
278    typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,typename Lhs::Scalar,typename Rhs::Scalar,
279          MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualRhs::MaxColsAtCompileTime> BlockingType;
280
281    BlockingType blocking(size, size, depth, 1, false);
282
283    internal::general_matrix_matrix_triangular_product<Index,
284      typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
285      typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
286      IsRowMajor ? RowMajor : ColMajor, UpLo>
287      ::run(size, depth,
288            &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
289            mat.data(), mat.outerStride(), actualAlpha, blocking);
290  }
291};
292
293template<typename MatrixType, unsigned int UpLo>
294template<typename ProductType>
295TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
296{
297  eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
298
299  general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);
300
301  return derived();
302}
303
304} // end namespace Eigen
305
306#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
307