// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_H

namespace Eigen {

namespace internal {

// pack a selfadjoint block diagonal for use with the gebp_kernel
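// Only one triangle of the selfadjoint lhs is actually stored. For a block of
// BlockRows rows starting at row i, pack() therefore reads the matrix in three
// parts:
//  - columns k < i: a plain copy of the stored triangle,
//  - columns i <= k < i+BlockRows: the diagonal block, rebuilt from stored
//    entries, their conjugates, and the (real) diagonal coefficients,
//  - columns k >= i+BlockRows: a conjugated, transposed copy of the stored
//    triangle.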
template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder>
struct symm_pack_lhs
{
  template<int BlockRows> inline
  void pack(Scalar* blockA, const const_blas_data_mapper<Scalar,Index,StorageOrder>& lhs, Index cols, Index i, Index& count)
  {
    // normal copy
    for(Index k=0; k<i; k++)
      for(Index w=0; w<BlockRows; w++)
        blockA[count++] = lhs(i+w,k);           // normal
    // symmetric copy
    Index h = 0;
    for(Index k=i; k<i+BlockRows; k++)
    {
      for(Index w=0; w<h; w++)
        blockA[count++] = conj(lhs(k, i+w)); // transposed

      blockA[count++] = real(lhs(k,k));   // real (diagonal)

      for(Index w=h+1; w<BlockRows; w++)
        blockA[count++] = lhs(i+w, k);          // normal
      ++h;
    }
    // transposed copy
    for(Index k=i+BlockRows; k<cols; k++)
      for(Index w=0; w<BlockRows; w++)
        blockA[count++] = conj(lhs(k, i+w)); // transposed
  }
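
  // A small worked example (assuming the mapper exposes the lower triangle,
  // BlockRows==2, i==0, cols==3): pack() emits, in this order,
  //   real(a00), a10,        <- column k=0 of rows 0..1 (stored entries)
  //   conj(a10), real(a11),  <- column k=1 (upper part rebuilt by conjugation)
  //   conj(a20), conj(a21)   <- column k=2 (transposed copy)
  // i.e. blockA receives two full rows of the logical matrix, reconstructed
  // from the stored triangle, column by column as gebp_kernel expects.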
  void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
  {
    const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride);
    Index count = 0;
    Index peeled_mc = (rows/Pack1)*Pack1;
    for(Index i=0; i<peeled_mc; i+=Pack1)
    {
      pack<Pack1>(blockA, lhs, cols, i, count);
    }

    if(rows-peeled_mc>=Pack2)
    {
      pack<Pack2>(blockA, lhs, cols, peeled_mc, count);
      peeled_mc += Pack2;
    }

    // do the same with mr==1
    for(Index i=peeled_mc; i<rows; i++)
    {
      for(Index k=0; k<i; k++)
        blockA[count++] = lhs(i, k);              // normal

      blockA[count++] = real(lhs(i, i));       // real (diagonal)

      for(Index k=i+1; k<cols; k++)
        blockA[count++] = conj(lhs(k, i));     // transposed
    }
  }
};

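// pack a panel of a selfadjoint rhs for use with the gebp_kernel: rows
// [k2,k2+rows) are packed across all cols, reading the stored triangle
// directly where possible and its conjugate transpose elsewhere. As for the
// lhs, the panel splits into a normal part (columns left of the diagonal
// block), the diagonal block itself, and a transposed part (columns right
// of it).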
template<typename Scalar, typename Index, int nr, int StorageOrder>
struct symm_pack_rhs
{
  enum { PacketSize = packet_traits<Scalar>::size };
  void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    Index end_k = k2 + rows;
    Index count = 0;
    const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(_rhs,rhsStride);
    Index packet_cols = (cols/nr)*nr;

    // first part: normal case
    for(Index j2=0; j2<k2; j2+=nr)
    {
      for(Index k=k2; k<end_k; k++)
      {
        blockB[count+0] = rhs(k,j2+0);
        blockB[count+1] = rhs(k,j2+1);
        if (nr==4)
        {
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
        }
        count += nr;
      }
    }

    // second part: diagonal block
    for(Index j2=k2; j2<(std::min)(k2+rows,packet_cols); j2+=nr)
    {
      // again we can split vertically into three different parts (transpose, symmetric, normal)
      // transpose
      for(Index k=k2; k<j2; k++)
      {
        blockB[count+0] = conj(rhs(j2+0,k));
        blockB[count+1] = conj(rhs(j2+1,k));
        if (nr==4)
        {
          blockB[count+2] = conj(rhs(j2+2,k));
          blockB[count+3] = conj(rhs(j2+3,k));
        }
        count += nr;
      }
      // symmetric
      Index h = 0;
      for(Index k=j2; k<j2+nr; k++)
      {
        // normal
        for (Index w=0 ; w<h; ++w)
          blockB[count+w] = rhs(k,j2+w);

        blockB[count+h] = real(rhs(k,k));

        // transpose
        for (Index w=h+1 ; w<nr; ++w)
          blockB[count+w] = conj(rhs(j2+w,k));
        count += nr;
        ++h;
      }
      // normal
      for(Index k=j2+nr; k<end_k; k++)
      {
        blockB[count+0] = rhs(k,j2+0);
        blockB[count+1] = rhs(k,j2+1);
        if (nr==4)
        {
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
        }
        count += nr;
      }
    }

    // third part: transposed
    for(Index j2=k2+rows; j2<packet_cols; j2+=nr)
    {
      for(Index k=k2; k<end_k; k++)
      {
        blockB[count+0] = conj(rhs(j2+0,k));
        blockB[count+1] = conj(rhs(j2+1,k));
        if (nr==4)
        {
          blockB[count+2] = conj(rhs(j2+2,k));
          blockB[count+3] = conj(rhs(j2+3,k));
        }
        count += nr;
      }
    }

    // copy the remaining columns one at a time (=> the same with nr==1)
    for(Index j2=packet_cols; j2<cols; ++j2)
    {
      // transpose
      Index half = (std::min)(end_k,j2);
      for(Index k=k2; k<half; k++)
      {
        blockB[count] = conj(rhs(j2,k));
        count += 1;
      }

      if(half==j2 && half<k2+rows)
      {
        blockB[count] = real(rhs(j2,j2));
        count += 1;
      }
      else
        half--;

      // normal
      for(Index k=half+1; k<k2+rows; k++)
      {
        blockB[count] = rhs(k,j2);
        count += 1;
      }
    }
  }
};

/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
 * the general matrix matrix product.
 */
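// The primary template below is only declared; the specializations that follow
// cover the three cases needed here: a row-major result (reduced to the
// column-major case by transposing the whole product), a selfadjoint lhs times
// a general rhs, and a general lhs times a selfadjoint rhs.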
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
          int ResStorageOrder>
struct product_selfadjoint_matrix;

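// A row-major result is handled by evaluating the transposed product:
// (A*B)^T = B^T * A^T is the same memory interpreted as column-major. For a
// general operand the transpose is obtained by flipping its storage order;
// for a selfadjoint operand M^T == conj(M), so its storage order is kept and
// its conjugation flag is toggled instead. Hence the XORs below.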
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor>
{

  static EIGEN_STRONG_INLINE void run(
    Index rows, Index cols,
    const Scalar* lhs, Index lhsStride,
    const Scalar* rhs, Index rhsStride,
    Scalar* res,       Index resStride,
    Scalar alpha)
  {
    product_selfadjoint_matrix<Scalar, Index,
      EIGEN_LOGICAL_XOR(RhsSelfAdjoint,RhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
      RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs),
      EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
      LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs),
      ColMajor>
      ::run(cols, rows,  rhs, rhsStride,  lhs, lhsStride,  res, resStride,  alpha);
  }
};

template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>
{

  static EIGEN_DONT_INLINE void run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* res,        Index resStride,
    Scalar alpha)
  {
    Index size = rows;

    const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
    const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);

    typedef gebp_traits<Scalar,Scalar> Traits;

    Index kc = size;  // cache block size along the K direction
    Index mc = rows;  // cache block size along the M direction
    Index nc = cols;  // cache block size along the N direction
    computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc);
    // kc must be smaller than mc: the packed copy of the diagonal block below is
    // an actual_kc x actual_kc square, and it has to fit into blockA (kc*mc scalars)
    kc = (std::min)(kc,mc);

    std::size_t sizeW = kc*Traits::WorkSpaceFactor;
    std::size_t sizeB = sizeW + kc*cols;
    ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
    ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
    Scalar* blockB = allocatedBlockB + sizeW;
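    // layout of allocatedBlockB, as set up above: the first sizeW scalars are
    // scratch space for the gebp kernel (hence Traits::WorkSpaceFactor), and
    // the packed rhs panel of kc*cols coefficients starts right after it,
    // at blockB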

    gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
    symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
    gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
    gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;

    for(Index k2=0; k2<size; k2+=kc)
    {
      const Index actual_kc = (std::min)(k2+kc,size)-k2;

      // we have selected one row panel of rhs and one column panel of lhs
      // pack rhs's panel into a sequential chunk of memory
      // and expand each coeff to a constant packet for further reuse
      pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);

      // the selected lhs panel has to be split into three different parts:
      //  1 - the transposed panel above the diagonal block => transposed packed copy
      //  2 - the diagonal block => special packed copy
      //  3 - the panel below the diagonal block => generic packed copy
      for(Index i2=0; i2<k2; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,k2)-i2;
        // transposed packed copy
        pack_lhs_transposed(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc);

        gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
      }
      // the block diagonal
      {
        const Index actual_mc = (std::min)(k2+kc,size)-k2;
        // symmetric packed copy
        pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);

        gebp_kernel(res+k2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
      }

      for(Index i2=k2+kc; i2<size; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,size)-i2;
        gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder,false>()
          (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);

        gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
      }
    }
  }
};

// matrix * selfadjoint product
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>
{

  static EIGEN_DONT_INLINE void run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* res,        Index resStride,
    Scalar alpha)
  {
    Index size = cols;

    const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);

    typedef gebp_traits<Scalar,Scalar> Traits;

    Index kc = size;  // cache block size along the K direction
    Index mc = rows;  // cache block size along the M direction
    Index nc = cols;  // cache block size along the N direction
    computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc);
    std::size_t sizeW = kc*Traits::WorkSpaceFactor;
    std::size_t sizeB = sizeW + kc*cols;
    ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
    ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
    Scalar* blockB = allocatedBlockB + sizeW;

    gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
    gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
    symm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;

    for(Index k2=0; k2<size; k2+=kc)
    {
      const Index actual_kc = (std::min)(k2+kc,size)-k2;

      pack_rhs(blockB, _rhs, rhsStride, actual_kc, cols, k2);

      // => GEPP: from here on this is a plain general panel*panel product,
      // since all the symmetry handling was done while packing the rhs
      for(Index i2=0; i2<rows; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,rows)-i2;
        pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);

        gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
      }
    }
  }
};

} // end namespace internal

/***************************************************************************
* Wrapper to product_selfadjoint_matrix
***************************************************************************/

namespace internal {
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> >
  : traits<ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs> >
{};
}

template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
  : public ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs >
{
  EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix)

  SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}

  enum {
    LhsIsUpper = (LhsMode&(Upper|Lower))==Upper,
    LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint,
    RhsIsUpper = (RhsMode&(Upper|Lower))==Upper,
    RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint
  };

  template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
  {
    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

    typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
    typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);

    Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
                               * RhsBlasTraits::extractScalarFactor(m_rhs);

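    // the packing kernels assume the stored triangle can be read directly; an
    // Upper-stored selfadjoint matrix is instead accessed as the conjugate
    // transpose of a Lower-stored one (using A == A^H), which is why IsUpper
    // is XORed into both the storage-order and conjugation flags below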
    internal::product_selfadjoint_matrix<Scalar, Index,
      EIGEN_LOGICAL_XOR(LhsIsUpper,
                        internal::traits<Lhs>::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
      NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)),
      EIGEN_LOGICAL_XOR(RhsIsUpper,
                        internal::traits<Rhs>::Flags & RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
      NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
      internal::traits<Dest>::Flags & RowMajorBit ? RowMajor : ColMajor>
      ::run(
        lhs.rows(), rhs.cols(),                    // sizes
        &lhs.coeffRef(0,0), lhs.outerStride(),     // lhs info
        &rhs.coeffRef(0,0), rhs.outerStride(),     // rhs info
        &dst.coeffRef(0,0), dst.outerStride(),     // result info
        actualAlpha                                // alpha
      );
  }
};
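
// Example usage (a sketch of user code, not part of this header; n and m are
// arbitrary sizes): both lines below dispatch to
// internal::product_selfadjoint_matrix<...>::run and read only the named
// triangle of A.
//   using namespace Eigen;
//   MatrixXcd A(n,n), B(n,m), C(n,m), D(m,n);
//   C.noalias() = A.selfadjointView<Lower>() * B;            // selfadjoint lhs
//   D.noalias() = B.adjoint() * A.selfadjointView<Upper>();  // selfadjoint rhs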

} // end namespace Eigen

#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H