1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_BLASUTIL_H
11#define EIGEN_BLASUTIL_H
12
13// This file contains many lightweight helper classes used to
14// implement and control fast level 2 and level 3 BLAS-like routines.
15
16namespace Eigen {
17
18namespace internal {
19
20// forward declarations
21template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
22struct gebp_kernel;
23
24template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
25struct gemm_pack_rhs;
26
27template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
28struct gemm_pack_lhs;
29
30template<
31  typename Index,
32  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
33  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
34  int ResStorageOrder>
35struct general_matrix_matrix_product;
36
37template<typename Index,
38         typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
39         typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
40struct general_matrix_vector_product;
41
42
43template<bool Conjugate> struct conj_if;
44
45template<> struct conj_if<true> {
46  template<typename T>
47  inline T operator()(const T& x) const { return numext::conj(x); }
48  template<typename T>
49  inline T pconj(const T& x) const { return internal::pconj(x); }
50};
51
52template<> struct conj_if<false> {
53  template<typename T>
54  inline const T& operator()(const T& x) const { return x; }
55  template<typename T>
56  inline const T& pconj(const T& x) const { return x; }
57};
58
59// Generic implementation for custom complex types.
60template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
61struct conj_helper
62{
63  typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
64
65  EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
66  { return padd(c, pmul(x,y)); }
67
68  EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
69  { return conj_if<ConjLhs>()(x) *  conj_if<ConjRhs>()(y); }
70};
71
72template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
73{
74  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
75  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
76};
77
78template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
79{
80  typedef std::complex<RealScalar> Scalar;
81  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
82  { return c + pmul(x,y); }
83
84  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
85  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
86};
87
88template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
89{
90  typedef std::complex<RealScalar> Scalar;
91  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
92  { return c + pmul(x,y); }
93
94  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
95  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
96};
97
98template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
99{
100  typedef std::complex<RealScalar> Scalar;
101  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
102  { return c + pmul(x,y); }
103
104  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
105  { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
106};
107
108template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
109{
110  typedef std::complex<RealScalar> Scalar;
111  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
112  { return padd(c, pmul(x,y)); }
113  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
114  { return conj_if<Conj>()(x)*y; }
115};
116
117template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
118{
119  typedef std::complex<RealScalar> Scalar;
120  EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
121  { return padd(c, pmul(x,y)); }
122  EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
123  { return x*conj_if<Conj>()(y); }
124};
125
126template<typename From,typename To> struct get_factor {
127  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
128};
129
130template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
131  EIGEN_DEVICE_FUNC
132  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
133};
134
135
136template<typename Scalar, typename Index>
137class BlasVectorMapper {
138  public:
139  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
140
141  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
142    return m_data[i];
143  }
144  template <typename Packet, int AlignmentType>
145  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
146    return ploadt<Packet, AlignmentType>(m_data + i);
147  }
148
149  template <typename Packet>
150  EIGEN_DEVICE_FUNC bool aligned(Index i) const {
151    return (UIntPtr(m_data+i)%sizeof(Packet))==0;
152  }
153
154  protected:
155  Scalar* m_data;
156};
157
158template<typename Scalar, typename Index, int AlignmentType>
159class BlasLinearMapper {
160  public:
161  typedef typename packet_traits<Scalar>::type Packet;
162  typedef typename packet_traits<Scalar>::half HalfPacket;
163
164  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
165
166  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
167    internal::prefetch(&operator()(i));
168  }
169
170  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
171    return m_data[i];
172  }
173
174  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
175    return ploadt<Packet, AlignmentType>(m_data + i);
176  }
177
178  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
179    return ploadt<HalfPacket, AlignmentType>(m_data + i);
180  }
181
182  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
183    pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
184  }
185
186  protected:
187  Scalar *m_data;
188};
189
190// Lightweight helper class to access matrix coefficients.
191template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
192class blas_data_mapper {
193  public:
194  typedef typename packet_traits<Scalar>::type Packet;
195  typedef typename packet_traits<Scalar>::half HalfPacket;
196
197  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
198  typedef BlasVectorMapper<Scalar, Index> VectorMapper;
199
200  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
201
202  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
203  getSubMapper(Index i, Index j) const {
204    return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
205  }
206
207  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
208    return LinearMapper(&operator()(i, j));
209  }
210
211  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
212    return VectorMapper(&operator()(i, j));
213  }
214
215
216  EIGEN_DEVICE_FUNC
217  EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
218    return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
219  }
220
221  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
222    return ploadt<Packet, AlignmentType>(&operator()(i, j));
223  }
224
225  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
226    return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
227  }
228
229  template<typename SubPacket>
230  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
231    pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
232  }
233
234  template<typename SubPacket>
235  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
236    return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
237  }
238
239  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
240  EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
241
242  EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
243    if (UIntPtr(m_data)%sizeof(Scalar)) {
244      return -1;
245    }
246    return internal::first_default_aligned(m_data, size);
247  }
248
249  protected:
250  Scalar* EIGEN_RESTRICT m_data;
251  const Index m_stride;
252};
253
254// lightweight helper class to access matrix coefficients (const version)
255template<typename Scalar, typename Index, int StorageOrder>
256class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
257  public:
258  EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
259
260  EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
261    return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
262  }
263};
264
265
266/* Helper class to analyze the factors of a Product expression.
267 * In particular it allows to pop out operator-, scalar multiples,
268 * and conjugate */
269template<typename XprType> struct blas_traits
270{
271  typedef typename traits<XprType>::Scalar Scalar;
272  typedef const XprType& ExtractType;
273  typedef XprType _ExtractType;
274  enum {
275    IsComplex = NumTraits<Scalar>::IsComplex,
276    IsTransposed = false,
277    NeedToConjugate = false,
278    HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
279                              && (   bool(XprType::IsVectorAtCompileTime)
280                                  || int(inner_stride_at_compile_time<XprType>::ret) == 1)
281                             ) ?  1 : 0
282  };
283  typedef typename conditional<bool(HasUsableDirectAccess),
284    ExtractType,
285    typename _ExtractType::PlainObject
286    >::type DirectLinearAccessType;
287  static inline ExtractType extract(const XprType& x) { return x; }
288  static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
289};
290
291// pop conjugate
292template<typename Scalar, typename NestedXpr>
293struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
294 : blas_traits<NestedXpr>
295{
296  typedef blas_traits<NestedXpr> Base;
297  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
298  typedef typename Base::ExtractType ExtractType;
299
300  enum {
301    IsComplex = NumTraits<Scalar>::IsComplex,
302    NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
303  };
304  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
305  static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
306};
307
308// pop scalar multiple
309template<typename Scalar, typename NestedXpr, typename Plain>
310struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
311 : blas_traits<NestedXpr>
312{
313  typedef blas_traits<NestedXpr> Base;
314  typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
315  typedef typename Base::ExtractType ExtractType;
316  static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
317  static inline Scalar extractScalarFactor(const XprType& x)
318  { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
319};
320template<typename Scalar, typename NestedXpr, typename Plain>
321struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
322 : blas_traits<NestedXpr>
323{
324  typedef blas_traits<NestedXpr> Base;
325  typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
326  typedef typename Base::ExtractType ExtractType;
327  static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
328  static inline Scalar extractScalarFactor(const XprType& x)
329  { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
330};
331template<typename Scalar, typename Plain1, typename Plain2>
332struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
333                                                            const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
334 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
335{};
336
337// pop opposite
338template<typename Scalar, typename NestedXpr>
339struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
340 : blas_traits<NestedXpr>
341{
342  typedef blas_traits<NestedXpr> Base;
343  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
344  typedef typename Base::ExtractType ExtractType;
345  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
346  static inline Scalar extractScalarFactor(const XprType& x)
347  { return - Base::extractScalarFactor(x.nestedExpression()); }
348};
349
350// pop/push transpose
351template<typename NestedXpr>
352struct blas_traits<Transpose<NestedXpr> >
353 : blas_traits<NestedXpr>
354{
355  typedef typename NestedXpr::Scalar Scalar;
356  typedef blas_traits<NestedXpr> Base;
357  typedef Transpose<NestedXpr> XprType;
358  typedef Transpose<const typename Base::_ExtractType>  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
359  typedef Transpose<const typename Base::_ExtractType> _ExtractType;
360  typedef typename conditional<bool(Base::HasUsableDirectAccess),
361    ExtractType,
362    typename ExtractType::PlainObject
363    >::type DirectLinearAccessType;
364  enum {
365    IsTransposed = Base::IsTransposed ? 0 : 1
366  };
367  static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
368  static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
369};
370
371template<typename T>
372struct blas_traits<const T>
373     : blas_traits<T>
374{};
375
376template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
377struct extract_data_selector {
378  static const typename T::Scalar* run(const T& m)
379  {
380    return blas_traits<T>::extract(m).data();
381  }
382};
383
384template<typename T>
385struct extract_data_selector<T,false> {
386  static typename T::Scalar* run(const T&) { return 0; }
387};
388
389template<typename T> const typename T::Scalar* extract_data(const T& m)
390{
391  return extract_data_selector<T>::run(m);
392}
393
394} // end namespace internal
395
396} // end namespace Eigen
397
398#endif // EIGEN_BLASUTIL_H
399