1/*
2 Copyright (c) 2011, Intel Corporation. All rights reserved.
3
4 Redistribution and use in source and binary forms, with or without modification,
5 are permitted provided that the following conditions are met:
6
7 * Redistributions of source code must retain the above copyright notice, this
8   list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright notice,
10   this list of conditions and the following disclaimer in the documentation
11   and/or other materials provided with the distribution.
12 * Neither the name of Intel Corporation nor the names of its contributors may
13   be used to endorse or promote products derived from this software without
14   specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27 ********************************************************************************
28 *   Content : Eigen bindings to Intel(R) MKL
29 *   Triangular matrix * matrix product functionality based on ?TRMM.
30 ********************************************************************************
31*/
32
33#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
34#define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
35
36namespace Eigen {
37
38namespace internal {
39
40
41template <typename Scalar, typename Index,
42          int Mode, bool LhsIsTriangular,
43          int LhsStorageOrder, bool ConjugateLhs,
44          int RhsStorageOrder, bool ConjugateRhs,
45          int ResStorageOrder>
46struct product_triangular_matrix_matrix_trmm :
47       product_triangular_matrix_matrix<Scalar,Index,Mode,
48          LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49          RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
50
51
// Hook into the triangular product dispatch: for the given scalar type and
// triangular side, the Specialized variant of product_triangular_matrix_matrix
// with a column-major result hands all of its arguments, unchanged, to
// product_triangular_matrix_matrix_trmm.
//   Scalar          - scalar type the specialization is generated for
//   LhsIsTriangular - true when the triangular factor is the left operand
#define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> \
{ \
  static inline void run( \
    Index _rows, Index _cols, Index _depth, \
    const Scalar* _lhs, Index lhsStride, \
    const Scalar* _rhs, Index rhsStride, \
    Scalar* res,        Index resStride, \
    Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) \
  { \
    /* Same template arguments, minus the Specialized tag. */ \
    typedef product_triangular_matrix_matrix_trmm<Scalar,Index,Mode,LhsIsTriangular, \
              LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> TrmmImpl; \
    TrmmImpl::run(_rows, _cols, _depth, \
                  _lhs, lhsStride, \
                  _rhs, rhsStride, \
                  res, resStride, \
                  alpha, blocking); \
  } \
};
67
68EIGEN_MKL_TRMM_SPECIALIZE(double, true)
69EIGEN_MKL_TRMM_SPECIALIZE(double, false)
70EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
71EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
72EIGEN_MKL_TRMM_SPECIALIZE(float, true)
73EIGEN_MKL_TRMM_SPECIALIZE(float, false)
74EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
75EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
76
// Left-side kernel: col-major res += alpha * op(triangular lhs) * op(general rhs).
// Each invocation generates one partial specialization of
// product_triangular_matrix_matrix_trmm (LhsIsTriangular == true, ColMajor result).
//   EIGTYPE   - Eigen scalar type of all operands
//   MKLTYPE   - MKL scalar type the data pointers are cast to for the ?trmm call
//   EIGPREFIX - suffix pasted onto MatrixX to name the dense work matrix type
//   MKLPREFIX - prefix pasted onto "trmm" to name the MKL routine
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    /* 1 when Mode selects the lower triangle */ \
    IsLower = (Mode&Lower) == Lower, \
    /* 1 when the stored diagonal is used as is (neither ZeroDiag nor UnitDiag) */ \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    /* 1 when a column-major lhs has to be conjugated in a copy; */ \
    /* a row-major lhs gets its conjugation through transa = 'C' instead */ \
    conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> TriStorage; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> GenStorage; \
    typedef Map<const TriStorage, 0, OuterStride<> > TriMap; \
    typedef Map<const GenStorage, 0, OuterStride<> > GenMap; \
\
    /* Sizes of the triangular factor actually taking part in the product. */ \
    const Index diagSize = (std::min)(_rows,_depth); \
    const Index rows     = IsLower ? _rows    : diagSize; \
    const Index depth    = IsLower ? diagSize : _depth; \
    const Index cols     = _cols; \
\
    /* A non-square triangular factor cannot be handed to ?trmm: use either */ \
    /* the built-in triangular product or a general matrix product. */ \
    if (rows != depth) { \
      const int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
      /* Single-threaded and the part outside the square is below half of */ \
      /* the diagonal size: most likely nothing to gain from MKL. */ \
      const bool stayBuiltIn = \
        (nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5); \
\
      if (stayBuiltIn) { \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor,BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
        return; \
      } \
\
      /* Otherwise evaluate the triangular view into a dense temporary and */ \
      /* run the general matrix-matrix product on it. */ \
      TriMap triSrc(_lhs, rows, depth, OuterStride<>(lhsStride)); \
      TriStorage aa_dense = triSrc.template triangularView<Mode>(); \
      MKL_INT aStride = aa_dense.outerStride(); \
      gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
      general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs, \
                                    EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
        rows, cols, depth, aa_dense.data(), aStride, _rhs, rhsStride, \
        res, resStride, alpha, gemm_blocking, 0); \
      return; \
    } \
\
    /* Square triangular factor: set up the ?trmm arguments. */ \
    char side   = 'L'; \
    char diag   = 'N'; \
    /* Row-major storage is passed as the transpose ('C' also conjugates). */ \
    char transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
    /* ... which also swaps the triangle that holds the data. */ \
    const bool lowerAsStored = (LhsStorageOrder==RowMajor) ? (IsLower==0) : (IsLower!=0); \
    char uplo   = lowerAsStored ? 'L' : 'U'; \
\
    MKL_INT m = (MKL_INT)diagSize; \
    MKL_INT n = (MKL_INT)cols; \
\
    MKLTYPE alpha_; \
    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
    /* b: column-major working copy of op(rhs); it is handed to ?trmm as a */ \
    /* non-const buffer and added into res afterwards. */ \
    GenMap rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \
    MatrixX##EIGPREFIX bWork; \
    if (ConjugateRhs) bWork = rhsMap.conjugate(); \
    else              bWork = rhsMap; \
    MKL_INT ldb = bWork.outerStride(); \
\
    /* a: the caller's lhs buffer is used directly unless it needs an */ \
    /* explicit conjugation or a forced zero/unit diagonal, in which case */ \
    /* a modified copy is passed instead. */ \
    TriMap lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \
    TriStorage aCopy; \
    const EIGTYPE* a = _lhs; \
    MKL_INT lda = lhsStride; \
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) aCopy = lhsMap.conjugate(); \
      else       aCopy = lhsMap; \
      if (IsZeroDiag) \
        aCopy.diagonal().setZero(); \
      else if (IsUnitDiag) \
        aCopy.diagonal().setOnes(); \
      a   = aCopy.data(); \
      lda = aCopy.outerStride(); \
    } \
\
    /* call ?trmm */ \
    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, \
                    (const MKLTYPE*)a, &lda, (MKLTYPE*)bWork.data(), &ldb); \
\
    /* Add op(a_triangular)*b into res. */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > resMap(res, rows, cols, OuterStride<>(resStride)); \
    resMap += bWork; \
  } \
};
185
186EIGEN_MKL_TRMM_L(double, double, d, d)
187EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
188EIGEN_MKL_TRMM_L(float, float, f, s)
189EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
190
// Right-side kernel: col-major res += alpha * op(general lhs) * op(triangular rhs).
// Each invocation generates one partial specialization of
// product_triangular_matrix_matrix_trmm (LhsIsTriangular == false, ColMajor result).
//   EIGTYPE   - Eigen scalar type of all operands
//   MKLTYPE   - MKL scalar type the data pointers are cast to for the ?trmm call
//   EIGPREFIX - suffix pasted onto MatrixX to name the dense work matrix type
//   MKLPREFIX - prefix pasted onto "trmm" to name the MKL routine
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    /* 1 when Mode selects the lower triangle */ \
    IsLower = (Mode&Lower) == Lower, \
    /* 1 when the stored diagonal is used as is (neither ZeroDiag nor UnitDiag) */ \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    /* 1 when a column-major rhs has to be conjugated in a copy; */ \
    /* a row-major rhs gets its conjugation through transa = 'C' instead */ \
    conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> GenStorage; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> TriStorage; \
    typedef Map<const GenStorage, 0, OuterStride<> > GenMap; \
    typedef Map<const TriStorage, 0, OuterStride<> > TriMap; \
\
    /* Sizes of the triangular factor actually taking part in the product. */ \
    const Index diagSize = (std::min)(_cols,_depth); \
    const Index rows     = _rows; \
    const Index depth    = IsLower ? _depth   : diagSize; \
    const Index cols     = IsLower ? diagSize : _cols; \
\
    /* A non-square triangular factor cannot be handed to ?trmm: use either */ \
    /* the built-in triangular product or a general matrix product. */ \
    if (cols != depth) { \
      const int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
      /* Single-threaded and the part outside the square is below half of */ \
      /* the diagonal size: most likely nothing to gain from MKL. */ \
      const bool stayBuiltIn = \
        (nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5); \
\
      if (stayBuiltIn) { \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor,BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
        return; \
      } \
\
      /* Otherwise evaluate the triangular view into a dense temporary and */ \
      /* run the general matrix-matrix product on it. */ \
      TriMap triSrc(_rhs, depth, cols, OuterStride<>(rhsStride)); \
      TriStorage aa_dense = triSrc.template triangularView<Mode>(); \
      MKL_INT aStride = aa_dense.outerStride(); \
      gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
      general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs, \
                                    EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
        rows, cols, depth, _lhs, lhsStride, aa_dense.data(), aStride, \
        res, resStride, alpha, gemm_blocking, 0); \
      return; \
    } \
\
    /* Square triangular factor: set up the ?trmm arguments. */ \
    char side   = 'R'; \
    char diag   = 'N'; \
    /* Row-major storage is passed as the transpose ('C' also conjugates). */ \
    char transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
    /* ... which also swaps the triangle that holds the data. */ \
    const bool lowerAsStored = (RhsStorageOrder==RowMajor) ? (IsLower==0) : (IsLower!=0); \
    char uplo   = lowerAsStored ? 'L' : 'U'; \
\
    MKL_INT m = (MKL_INT)rows; \
    MKL_INT n = (MKL_INT)diagSize; \
\
    MKLTYPE alpha_; \
    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
    /* b: column-major working copy of op(lhs); it is handed to ?trmm as a */ \
    /* non-const buffer and added into res afterwards. */ \
    GenMap lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \
    MatrixX##EIGPREFIX bWork; \
    if (ConjugateLhs) bWork = lhsMap.conjugate(); \
    else              bWork = lhsMap; \
    MKL_INT ldb = bWork.outerStride(); \
\
    /* a: the caller's rhs buffer is used directly unless it needs an */ \
    /* explicit conjugation or a forced zero/unit diagonal, in which case */ \
    /* a modified copy is passed instead. */ \
    TriMap rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \
    TriStorage aCopy; \
    const EIGTYPE* a = _rhs; \
    MKL_INT lda = rhsStride; \
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) aCopy = rhsMap.conjugate(); \
      else       aCopy = rhsMap; \
      if (IsZeroDiag) \
        aCopy.diagonal().setZero(); \
      else if (IsUnitDiag) \
        aCopy.diagonal().setOnes(); \
      a   = aCopy.data(); \
      lda = aCopy.outerStride(); \
    } \
\
    /* call ?trmm */ \
    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, \
                    (const MKLTYPE*)a, &lda, (MKLTYPE*)bWork.data(), &ldb); \
\
    /* Add b*op(a_triangular) into res. */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > resMap(res, rows, cols, OuterStride<>(resStride)); \
    resMap += bWork; \
  } \
};
299
300EIGEN_MKL_TRMM_R(double, double, d, d)
301EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
302EIGEN_MKL_TRMM_R(float, float, f, s)
303EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
304
305} // end namespace internal
306
307} // end namespace Eigen
308
309#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
310