1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5// Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_DIAGONALPRODUCT_H
12#define EIGEN_DIAGONALPRODUCT_H
13
14namespace Eigen {
15
16namespace internal {
17template<typename MatrixType, typename DiagonalType, int ProductOrder>
18struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
19 : traits<MatrixType>
20{
21  typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
22  enum {
23    RowsAtCompileTime = MatrixType::RowsAtCompileTime,
24    ColsAtCompileTime = MatrixType::ColsAtCompileTime,
25    MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
26    MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
27
28    _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
29    _ScalarAccessOnDiag =  !((int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheLeft)
30                          ||(int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheRight)),
31    _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
32    // FIXME currently we need same types, but in the future the next rule should be the one
33    //_Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagonalType::DiagonalVectorType::Flags)&PacketAccessBit))),
34    _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && (_ScalarAccessOnDiag || (bool(int(DiagonalType::DiagonalVectorType::Flags)&PacketAccessBit))),
35    _LinearAccessMask = (RowsAtCompileTime==1 || ColsAtCompileTime==1) ? LinearAccessBit : 0,
36
37    Flags = ((HereditaryBits|_LinearAccessMask) & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0) | AlignedBit,//(int(MatrixType::Flags)&int(DiagonalType::DiagonalVectorType::Flags)&AlignedBit),
38    CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
39  };
40};
41}
42
43template<typename MatrixType, typename DiagonalType, int ProductOrder>
44class DiagonalProduct : internal::no_assignment_operator,
45                        public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
46{
47  public:
48
49    typedef MatrixBase<DiagonalProduct> Base;
50    EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
51
52    inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
53      : m_matrix(matrix), m_diagonal(diagonal)
54    {
55      eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
56    }
57
58    EIGEN_STRONG_INLINE Index rows() const { return m_matrix.rows(); }
59    EIGEN_STRONG_INLINE Index cols() const { return m_matrix.cols(); }
60
61    EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
62    {
63      return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
64    }
65
66    EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
67    {
68      enum {
69        StorageOrder = int(MatrixType::Flags) & RowMajorBit ? RowMajor : ColMajor
70      };
71      return coeff(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
72    }
73
74    template<int LoadMode>
75    EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
76    {
77      enum {
78        StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
79      };
80      const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
81      return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
82        ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
83       ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
84    }
85
86    template<int LoadMode>
87    EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
88    {
89      enum {
90        StorageOrder = int(MatrixType::Flags) & RowMajorBit ? RowMajor : ColMajor
91      };
92      return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
93    }
94
95  protected:
96    template<int LoadMode>
97    EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
98    {
99      return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
100                     internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
101    }
102
103    template<int LoadMode>
104    EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
105    {
106      enum {
107        InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
108        DiagonalVectorPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::DiagonalVectorType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
109      };
110      return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
111                     m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
112    }
113
114    typename MatrixType::Nested m_matrix;
115    typename DiagonalType::Nested m_diagonal;
116};
117
118/** \returns the diagonal matrix product of \c *this by the diagonal matrix \a diagonal.
119  */
120template<typename Derived>
121template<typename DiagonalDerived>
122inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
123MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &a_diagonal) const
124{
125  return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), a_diagonal.derived());
126}
127
128} // end namespace Eigen
129
130#endif // EIGEN_DIAGONALPRODUCT_H
131