1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11#define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
12
13namespace Eigen {
14
// The product of a diagonal matrix with a sparse matrix can be easily
// implemented using expression templates.
// We have to consider two very different cases:
18// 1 - diag * row-major sparse
19//     => each inner vector <=> scalar * sparse vector product
20//     => so we can reuse CwiseUnaryOp::InnerIterator
21// 2 - diag * col-major sparse
22//     => each inner vector <=> densevector * sparse vector cwise product
23//     => again, we can reuse specialization of CwiseBinaryOp::InnerIterator
24//        for that particular case
25// The two other cases are symmetric.
26
27namespace internal {
28
29template<typename Lhs, typename Rhs>
30struct traits<SparseDiagonalProduct<Lhs, Rhs> >
31{
32  typedef typename remove_all<Lhs>::type _Lhs;
33  typedef typename remove_all<Rhs>::type _Rhs;
34  typedef typename _Lhs::Scalar Scalar;
35  typedef typename promote_index_type<typename traits<Lhs>::Index,
36                                         typename traits<Rhs>::Index>::type Index;
37  typedef Sparse StorageKind;
38  typedef MatrixXpr XprKind;
39  enum {
40    RowsAtCompileTime = _Lhs::RowsAtCompileTime,
41    ColsAtCompileTime = _Rhs::ColsAtCompileTime,
42
43    MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime,
44    MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime,
45
46    SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags),
47    Flags = (SparseFlags&RowMajorBit),
48    CoeffReadCost = Dynamic
49  };
50};
51
// Tags classifying one operand of a sparse-diagonal product: either the
// diagonal factor, or a sparse factor together with its storage order.
enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor};

// Selects the InnerIterator implementation of a SparseDiagonalProduct.
// Only declared here; each supported (LhsMode, RhsMode) combination is
// provided as a partial specialization.
// The fourth and fifth arguments are the tags of the left and right operand,
// in that order: this is the order used by the partial specializations and by
// SparseDiagonalProduct::InnerIterator when instantiating the selector.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int LhsMode, int RhsMode>
class sparse_diagonal_product_inner_iterator_selector;
55
56} // end namespace internal
57
58template<typename Lhs, typename Rhs>
59class SparseDiagonalProduct
60  : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >,
61    internal::no_assignment_operator
62{
63    typedef typename Lhs::Nested LhsNested;
64    typedef typename Rhs::Nested RhsNested;
65
66    typedef typename internal::remove_all<LhsNested>::type _LhsNested;
67    typedef typename internal::remove_all<RhsNested>::type _RhsNested;
68
69    enum {
70      LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal
71              : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor,
72      RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal
73              : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor
74    };
75
76  public:
77
78    EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct)
79
80    typedef internal::sparse_diagonal_product_inner_iterator_selector
81                      <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator;
82
83    // We do not want ReverseInnerIterator for diagonal-sparse products,
84    // but this dummy declaration is needed to make diag * sparse * diag compile.
85    class ReverseInnerIterator;
86
87    EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs)
88      : m_lhs(lhs), m_rhs(rhs)
89    {
90      eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product");
91    }
92
93    EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
94    EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
95
96    EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
97    EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
98
99  protected:
100    LhsNested m_lhs;
101    RhsNested m_rhs;
102};
103
104namespace internal {
105
106template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
107class sparse_diagonal_product_inner_iterator_selector
108<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
109  : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator
110{
111    typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base;
112    typedef typename Lhs::Index Index;
113  public:
114    inline sparse_diagonal_product_inner_iterator_selector(
115              const SparseDiagonalProductType& expr, Index outer)
116      : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer)
117    {}
118};
119
120template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
121class sparse_diagonal_product_inner_iterator_selector
122<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor>
123  : public CwiseBinaryOp<
124      scalar_product_op<typename Lhs::Scalar>,
125      const typename Rhs::ConstInnerVectorReturnType,
126      const typename Lhs::DiagonalVectorType>::InnerIterator
127{
128    typedef typename CwiseBinaryOp<
129      scalar_product_op<typename Lhs::Scalar>,
130      const typename Rhs::ConstInnerVectorReturnType,
131      const typename Lhs::DiagonalVectorType>::InnerIterator Base;
132    typedef typename Lhs::Index Index;
133    Index m_outer;
134  public:
135    inline sparse_diagonal_product_inner_iterator_selector(
136              const SparseDiagonalProductType& expr, Index outer)
137      : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer)
138    {}
139
140    inline Index outer() const { return m_outer; }
141    inline Index col() const { return m_outer; }
142};
143
144template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
145class sparse_diagonal_product_inner_iterator_selector
146<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal>
147  : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator
148{
149    typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base;
150    typedef typename Lhs::Index Index;
151  public:
152    inline sparse_diagonal_product_inner_iterator_selector(
153              const SparseDiagonalProductType& expr, Index outer)
154      : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer)
155    {}
156};
157
158template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
159class sparse_diagonal_product_inner_iterator_selector
160<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal>
161  : public CwiseBinaryOp<
162      scalar_product_op<typename Rhs::Scalar>,
163      const typename Lhs::ConstInnerVectorReturnType,
164      const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator
165{
166    typedef typename CwiseBinaryOp<
167      scalar_product_op<typename Rhs::Scalar>,
168      const typename Lhs::ConstInnerVectorReturnType,
169      const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
170    typedef typename Lhs::Index Index;
171    Index m_outer;
172  public:
173    inline sparse_diagonal_product_inner_iterator_selector(
174              const SparseDiagonalProductType& expr, Index outer)
175      : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer)
176    {}
177
178    inline Index outer() const { return m_outer; }
179    inline Index row() const { return m_outer; }
180};
181
182} // end namespace internal
183
184// SparseMatrixBase functions
185
186template<typename Derived>
187template<typename OtherDerived>
188const SparseDiagonalProduct<Derived,OtherDerived>
189SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const
190{
191  return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
192}
193
194} // end namespace Eigen
195
196#endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
197