1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11#define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
12
13namespace Eigen {
14
// The product of a diagonal matrix with a sparse matrix can be easily
// implemented using expression templates.
// We have to consider two very different cases:
18// 1 - diag * row-major sparse
19//     => each inner vector <=> scalar * sparse vector product
20//     => so we can reuse CwiseUnaryOp::InnerIterator
21// 2 - diag * col-major sparse
22//     => each inner vector <=> densevector * sparse vector cwise product
//     => again, we can reuse a specialization of CwiseBinaryOp::InnerIterator
//        for that particular case
// The two other cases (sparse * diag, in either storage order) are symmetric.
26
27namespace internal {
28
29template<typename Lhs, typename Rhs>
30struct traits<SparseDiagonalProduct<Lhs, Rhs> >
31{
32  typedef typename remove_all<Lhs>::type _Lhs;
33  typedef typename remove_all<Rhs>::type _Rhs;
34  typedef typename _Lhs::Scalar Scalar;
35  typedef typename promote_index_type<typename traits<Lhs>::Index,
36                                         typename traits<Rhs>::Index>::type Index;
37  typedef Sparse StorageKind;
38  typedef MatrixXpr XprKind;
39  enum {
40    RowsAtCompileTime = _Lhs::RowsAtCompileTime,
41    ColsAtCompileTime = _Rhs::ColsAtCompileTime,
42
43    MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime,
44    MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime,
45
46    SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags),
47    Flags = (SparseFlags&RowMajorBit),
48    CoeffReadCost = Dynamic
49  };
50};
51
// Storage category of one operand of a SparseDiagonalProduct; used as a
// template argument to pick the matching InnerIterator specialization.
enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor};

// Primary template, intentionally declared but never defined: only the
// diagonal/sparse operand combinations that are explicitly specialized can be
// instantiated. The mode parameters are named in the order they are passed,
// i.e. the left operand's mode first, then the right operand's.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int LhsMode, int RhsMode>
class sparse_diagonal_product_inner_iterator_selector;
55
56} // end namespace internal
57
58template<typename Lhs, typename Rhs>
59class SparseDiagonalProduct
60  : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >,
61    internal::no_assignment_operator
62{
63    typedef typename Lhs::Nested LhsNested;
64    typedef typename Rhs::Nested RhsNested;
65
66    typedef typename internal::remove_all<LhsNested>::type _LhsNested;
67    typedef typename internal::remove_all<RhsNested>::type _RhsNested;
68
69    enum {
70      LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal
71              : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor,
72      RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal
73              : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor
74    };
75
76  public:
77
78    EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct)
79
80    typedef internal::sparse_diagonal_product_inner_iterator_selector
81                <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator;
82
83    EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs)
84      : m_lhs(lhs), m_rhs(rhs)
85    {
86      eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product");
87    }
88
89    EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
90    EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
91
92    EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
93    EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
94
95  protected:
96    LhsNested m_lhs;
97    RhsNested m_rhs;
98};
99
100namespace internal {
101
// diag * row-major sparse.
// The product inherits the row-major bit from rhs, so `outer` indexes a row.
// Row `outer` of the result is row `outer` of rhs scaled by the diagonal
// coefficient d(outer). Iteration is delegated to the InnerIterator of the
// scaled expression rhs * d(outer) (a CwiseUnaryOp with scalar_multiple_op).
// NOTE(review): the scaled expression handed to Base is a temporary that dies
// at the end of the constructor. This assumes Base keeps only the functor and
// references into rhs itself, not into that temporary. Confirm against
// CwiseUnaryOp::InnerIterator.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
  : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator
{
    typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base;
    typedef typename Lhs::Index Index;
  public:
    // expr: the product expression; outer: the row of the result to iterate over.
    inline sparse_diagonal_product_inner_iterator_selector(
              const SparseDiagonalProductType& expr, Index outer)
      : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer)
    {}
};
115
// diag * col-major sparse.
// The product inherits the column-major layout from rhs, so `outer` indexes a column.
// Column `outer` of the result is column `outer` of rhs multiplied
// coefficient-wise by the diagonal vector of lhs. The base iterator walks
// that single-column expression, which is why it is given outer index 0.
// NOTE(review): as in the row-major case, the cwiseProduct expression passed
// to Base is a temporary. This relies on Base not retaining references into
// it. Confirm.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor>
  : public CwiseBinaryOp<
      scalar_product_op<typename Lhs::Scalar>,
      SparseInnerVectorSet<Rhs,1>,
      typename Lhs::DiagonalVectorType>::InnerIterator
{
    typedef typename CwiseBinaryOp<
      scalar_product_op<typename Lhs::Scalar>,
      SparseInnerVectorSet<Rhs,1>,
      typename Lhs::DiagonalVectorType>::InnerIterator Base;
    typedef typename Lhs::Index Index;
  public:
    // expr: the product expression; outer: the column of the result to iterate over.
    inline sparse_diagonal_product_inner_iterator_selector(
              const SparseDiagonalProductType& expr, Index outer)
      : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0)
    {}
};
135
// col-major sparse * diag.
// The product inherits the column-major layout from lhs, so `outer` indexes a column.
// Column `outer` of the result is column `outer` of lhs scaled by the
// diagonal coefficient d(outer) of rhs. Iteration is delegated to the
// InnerIterator of the scaled expression lhs * d(outer).
// NOTE(review): the scaled expression passed to Base is a temporary. This
// relies on Base not retaining references into it. Confirm.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal>
  : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator
{
    typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base;
    typedef typename Lhs::Index Index;
  public:
    // expr: the product expression; outer: the column of the result to iterate over.
    inline sparse_diagonal_product_inner_iterator_selector(
              const SparseDiagonalProductType& expr, Index outer)
      : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer)
    {}
};
149
// row-major sparse * diag.
// The product inherits the row-major bit from lhs, so `outer` indexes a row.
// Row `outer` of the result is row `outer` of lhs multiplied coefficient-wise
// by the diagonal vector of rhs. The diagonal vector is transposed, presumably
// so that its shape matches the row returned by innerVector(). The base
// iterator walks that single-row expression, which is why it is given outer
// index 0.
// NOTE(review): the cwiseProduct expression passed to Base is a temporary.
// This relies on Base not retaining references into it. Confirm.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal>
  : public CwiseBinaryOp<
      scalar_product_op<typename Rhs::Scalar>,
      SparseInnerVectorSet<Lhs,1>,
      Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator
{
    typedef typename CwiseBinaryOp<
      scalar_product_op<typename Rhs::Scalar>,
      SparseInnerVectorSet<Lhs,1>,
      Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
    typedef typename Lhs::Index Index;
  public:
    // expr: the product expression; outer: the row of the result to iterate over.
    inline sparse_diagonal_product_inner_iterator_selector(
              const SparseDiagonalProductType& expr, Index outer)
      : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0)
    {}
};
169
170} // end namespace internal
171
172// SparseMatrixBase functions
173
174template<typename Derived>
175template<typename OtherDerived>
176const SparseDiagonalProduct<Derived,OtherDerived>
177SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const
178{
179  return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
180}
181
182} // end namespace Eigen
183
184#endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
185