// SparseDiagonalProduct.h revision 7faaa9f3f0df9d23790277834d426c3d992ac3ba
1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H 11#define EIGEN_SPARSE_DIAGONAL_PRODUCT_H 12 13namespace Eigen { 14 15// The product of a diagonal matrix with a sparse matrix can be easily 16// implemented using expression template. 17// We have two consider very different cases: 18// 1 - diag * row-major sparse 19// => each inner vector <=> scalar * sparse vector product 20// => so we can reuse CwiseUnaryOp::InnerIterator 21// 2 - diag * col-major sparse 22// => each inner vector <=> densevector * sparse vector cwise product 23// => again, we can reuse specialization of CwiseBinaryOp::InnerIterator 24// for that particular case 25// The two other cases are symmetric. 26 27namespace internal { 28 29template<typename Lhs, typename Rhs> 30struct traits<SparseDiagonalProduct<Lhs, Rhs> > 31{ 32 typedef typename remove_all<Lhs>::type _Lhs; 33 typedef typename remove_all<Rhs>::type _Rhs; 34 typedef typename _Lhs::Scalar Scalar; 35 typedef typename promote_index_type<typename traits<Lhs>::Index, 36 typename traits<Rhs>::Index>::type Index; 37 typedef Sparse StorageKind; 38 typedef MatrixXpr XprKind; 39 enum { 40 RowsAtCompileTime = _Lhs::RowsAtCompileTime, 41 ColsAtCompileTime = _Rhs::ColsAtCompileTime, 42 43 MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime, 44 MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime, 45 46 SparseFlags = is_diagonal<_Lhs>::ret ? 
int(_Rhs::Flags) : int(_Lhs::Flags), 47 Flags = (SparseFlags&RowMajorBit), 48 CoeffReadCost = Dynamic 49 }; 50}; 51 52enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor}; 53template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int RhsMode, int LhsMode> 54class sparse_diagonal_product_inner_iterator_selector; 55 56} // end namespace internal 57 58template<typename Lhs, typename Rhs> 59class SparseDiagonalProduct 60 : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >, 61 internal::no_assignment_operator 62{ 63 typedef typename Lhs::Nested LhsNested; 64 typedef typename Rhs::Nested RhsNested; 65 66 typedef typename internal::remove_all<LhsNested>::type _LhsNested; 67 typedef typename internal::remove_all<RhsNested>::type _RhsNested; 68 69 enum { 70 LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal 71 : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor, 72 RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal 73 : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor 74 }; 75 76 public: 77 78 EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct) 79 80 typedef internal::sparse_diagonal_product_inner_iterator_selector 81 <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator; 82 83 // We do not want ReverseInnerIterator for diagonal-sparse products, 84 // but this dummy declaration is needed to make diag * sparse * diag compile. 
85 class ReverseInnerIterator; 86 87 EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs) 88 : m_lhs(lhs), m_rhs(rhs) 89 { 90 eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product"); 91 } 92 93 EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); } 94 EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); } 95 96 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } 97 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } 98 99 protected: 100 LhsNested m_lhs; 101 RhsNested m_rhs; 102}; 103 104namespace internal { 105 106template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 107class sparse_diagonal_product_inner_iterator_selector 108<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor> 109 : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator 110{ 111 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base; 112 typedef typename Lhs::Index Index; 113 public: 114 inline sparse_diagonal_product_inner_iterator_selector( 115 const SparseDiagonalProductType& expr, Index outer) 116 : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer) 117 {} 118}; 119 120template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 121class sparse_diagonal_product_inner_iterator_selector 122<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor> 123 : public CwiseBinaryOp< 124 scalar_product_op<typename Lhs::Scalar>, 125 const typename Rhs::ConstInnerVectorReturnType, 126 const typename Lhs::DiagonalVectorType>::InnerIterator 127{ 128 typedef typename CwiseBinaryOp< 129 scalar_product_op<typename Lhs::Scalar>, 130 const typename Rhs::ConstInnerVectorReturnType, 131 const typename Lhs::DiagonalVectorType>::InnerIterator Base; 132 typedef typename Lhs::Index Index; 133 Index m_outer; 134 public: 135 inline 
sparse_diagonal_product_inner_iterator_selector( 136 const SparseDiagonalProductType& expr, Index outer) 137 : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer) 138 {} 139 140 inline Index outer() const { return m_outer; } 141 inline Index col() const { return m_outer; } 142}; 143 144template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 145class sparse_diagonal_product_inner_iterator_selector 146<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal> 147 : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator 148{ 149 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base; 150 typedef typename Lhs::Index Index; 151 public: 152 inline sparse_diagonal_product_inner_iterator_selector( 153 const SparseDiagonalProductType& expr, Index outer) 154 : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer) 155 {} 156}; 157 158template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 159class sparse_diagonal_product_inner_iterator_selector 160<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal> 161 : public CwiseBinaryOp< 162 scalar_product_op<typename Rhs::Scalar>, 163 const typename Lhs::ConstInnerVectorReturnType, 164 const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator 165{ 166 typedef typename CwiseBinaryOp< 167 scalar_product_op<typename Rhs::Scalar>, 168 const typename Lhs::ConstInnerVectorReturnType, 169 const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base; 170 typedef typename Lhs::Index Index; 171 Index m_outer; 172 public: 173 inline sparse_diagonal_product_inner_iterator_selector( 174 const SparseDiagonalProductType& expr, Index outer) 175 : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer) 176 {} 177 178 inline Index outer() const { return m_outer; } 179 inline Index row() 
const { return m_outer; } 180}; 181 182} // end namespace internal 183 184// SparseMatrixBase functions 185 186template<typename Derived> 187template<typename OtherDerived> 188const SparseDiagonalProduct<Derived,OtherDerived> 189SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const 190{ 191 return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived()); 192} 193 194} // end namespace Eigen 195 196#endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H 197