// DiagonalProduct.h revision c981c48f5bc9aefeffc0bcb0cc3934c2fae179dd
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_DIAGONALPRODUCT_H
12#define EIGEN_DIAGONALPRODUCT_H
13
14namespace Eigen {
15
16namespace internal {
17template<typename MatrixType, typename DiagonalType, int ProductOrder>
18struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
19 : traits<MatrixType>
20{
21  typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
22  enum {
23    RowsAtCompileTime = MatrixType::RowsAtCompileTime,
24    ColsAtCompileTime = MatrixType::ColsAtCompileTime,
25    MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
26    MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
27
28    _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
29    _PacketOnDiag = !((int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
30                    ||(int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)),
31    _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
32    // FIXME currently we need same types, but in the future the next rule should be the one
33    //_Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagonalType::Flags)&PacketAccessBit))),
34    _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && ((!_PacketOnDiag) || (bool(int(DiagonalType::Flags)&PacketAccessBit))),
35
36    Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),
37    CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
38  };
39};
40}
41
42template<typename MatrixType, typename DiagonalType, int ProductOrder>
43class DiagonalProduct : internal::no_assignment_operator,
44                        public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
45{
46  public:
47
48    typedef MatrixBase<DiagonalProduct> Base;
49    EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
50
51    inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
52      : m_matrix(matrix), m_diagonal(diagonal)
53    {
54      eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
55    }
56
57    inline Index rows() const { return m_matrix.rows(); }
58    inline Index cols() const { return m_matrix.cols(); }
59
60    const Scalar coeff(Index row, Index col) const
61    {
62      return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
63    }
64
65    template<int LoadMode>
66    EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
67    {
68      enum {
69        StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
70      };
71      const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
72
73      return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
74        ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
75       ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
76    }
77
78  protected:
79    template<int LoadMode>
80    EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
81    {
82      return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
83                     internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
84    }
85
86    template<int LoadMode>
87    EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
88    {
89      enum {
90        InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
91        DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned
92      };
93      return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
94                     m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
95    }
96
97    typename MatrixType::Nested m_matrix;
98    typename DiagonalType::Nested m_diagonal;
99};
100
101/** \returns the diagonal matrix product of \c *this by the diagonal matrix \a diagonal.
102  */
103template<typename Derived>
104template<typename DiagonalDerived>
105inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
106MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &diagonal) const
107{
108  return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), diagonal.derived());
109}
110
111/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
112  */
113template<typename DiagonalDerived>
114template<typename MatrixDerived>
115inline const DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>
116DiagonalBase<DiagonalDerived>::operator*(const MatrixBase<MatrixDerived> &matrix) const
117{
118  return DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>(matrix.derived(), derived());
119}
120
121} // end namespace Eigen
122
123#endif // EIGEN_DIAGONALPRODUCT_H
124