1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 11#define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 12 13namespace Eigen { 14 15namespace internal { 16 17template<typename Lhs, typename Rhs, typename ResultType> 18static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res) 19{ 20 typedef typename remove_all<Lhs>::type::Scalar Scalar; 21 typedef typename remove_all<Lhs>::type::Index Index; 22 23 // make sure to call innerSize/outerSize since we fake the storage order. 24 Index rows = lhs.innerSize(); 25 Index cols = rhs.outerSize(); 26 eigen_assert(lhs.outerSize() == rhs.innerSize()); 27 28 std::vector<bool> mask(rows,false); 29 Matrix<Scalar,Dynamic,1> values(rows); 30 Matrix<Index,Dynamic,1> indices(rows); 31 32 // estimate the number of non zero entries 33 // given a rhs column containing Y non zeros, we assume that the respective Y columns 34 // of the lhs differs in average of one non zeros, thus the number of non zeros for 35 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 36 // per column of the lhs. 37 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 38 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); 39 40 res.setZero(); 41 res.reserve(Index(estimated_nnz_prod)); 42 // we compute each column of the result, one after the other 43 for (Index j=0; j<cols; ++j) 44 { 45 46 res.startVec(j); 47 Index nnz = 0; 48 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) 49 { 50 Scalar y = rhsIt.value(); 51 Index k = rhsIt.index(); 52 for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt) 53 { 54 Index i = lhsIt.index(); 55 Scalar x = lhsIt.value(); 56 if(!mask[i]) 57 { 58 mask[i] = true; 59 values[i] = x * y; 60 indices[nnz] = i; 61 ++nnz; 62 } 63 else 64 values[i] += x * y; 65 } 66 } 67 68 // unordered insertion 69 for(Index k=0; k<nnz; ++k) 70 { 71 Index i = indices[k]; 72 res.insertBackByOuterInnerUnordered(j,i) = values[i]; 73 mask[i] = false; 74 } 75 76#if 0 77 // alternative ordered insertion code: 78 79 Index t200 = rows/(log2(200)*1.39); 80 Index t = (rows*100)/139; 81 82 // FIXME reserve nnz non zeros 83 // FIXME implement fast sort algorithms for very small nnz 84 // if the result is sparse enough => use a quick sort 85 // otherwise => loop through the entire vector 86 // In order to avoid to perform an expensive log2 when the 87 // result is clearly very sparse we use a linear bound up to 200. 88 //if((nnz<200 && nnz<t200) || nnz * log2(nnz) < t) 89 //res.startVec(j); 90 if(true) 91 { 92 if(nnz>1) std::sort(indices.data(),indices.data()+nnz); 93 for(Index k=0; k<nnz; ++k) 94 { 95 Index i = indices[k]; 96 res.insertBackByOuterInner(j,i) = values[i]; 97 mask[i] = false; 98 } 99 } 100 else 101 { 102 // dense path 103 for(Index i=0; i<rows; ++i) 104 { 105 if(mask[i]) 106 { 107 mask[i] = false; 108 res.insertBackByOuterInner(j,i) = values[i]; 109 } 110 } 111 } 112#endif 113 114 } 115 res.finalize(); 116} 117 118 119} // end namespace internal 120 121namespace internal { 122 123template<typename Lhs, typename Rhs, typename ResultType, 124 int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor, 125 int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor, 126 int ResStorageOrder = (traits<ResultType>::Flags&RowMajorBit) ? RowMajor : ColMajor> 127struct conservative_sparse_sparse_product_selector; 128 129template<typename Lhs, typename Rhs, typename ResultType> 130struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 131{ 132 typedef typename remove_all<Lhs>::type LhsCleaned; 133 typedef typename LhsCleaned::Scalar Scalar; 134 135 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 136 { 137 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix; 138 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix; 139 ColMajorMatrix resCol(lhs.rows(),rhs.cols()); 140 internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); 141 // sort the non zeros: 142 RowMajorMatrix resRow(resCol); 143 res = resRow; 144 } 145}; 146 147template<typename Lhs, typename Rhs, typename ResultType> 148struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> 149{ 150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 151 { 152 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix; 153 RowMajorMatrix rhsRow = rhs; 154 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 155 internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow); 156 res = resRow; 157 } 158}; 159 160template<typename Lhs, typename Rhs, typename ResultType> 161struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> 162{ 163 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 164 { 165 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix; 166 RowMajorMatrix lhsRow = lhs; 167 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 168 internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow); 169 res = resRow; 170 } 171}; 172 173template<typename Lhs, typename Rhs, typename ResultType> 174struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 175{ 176 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 177 { 178 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix; 179 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 180 internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); 181 res = resRow; 182 } 183}; 184 185 186template<typename Lhs, typename Rhs, typename ResultType> 187struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 188{ 189 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 190 191 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 192 { 193 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix; 194 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 195 internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); 196 res = resCol; 197 } 198}; 199 200template<typename Lhs, typename Rhs, typename ResultType> 201struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> 202{ 203 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 204 { 205 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix; 206 ColMajorMatrix lhsCol = lhs; 207 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 208 internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol); 209 res = resCol; 210 } 211}; 212 213template<typename Lhs, typename Rhs, typename ResultType> 214struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> 215{ 216 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 217 { 218 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix; 219 ColMajorMatrix rhsCol = rhs; 220 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 221 internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol); 222 res = resCol; 223 } 224}; 225 226template<typename Lhs, typename Rhs, typename ResultType> 227struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 228{ 229 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 230 { 231 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix; 232 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix; 233 RowMajorMatrix resRow(lhs.rows(),rhs.cols()); 234 internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); 235 // sort the non zeros: 236 ColMajorMatrix resCol(resRow); 237 res = resCol; 238 } 239}; 240 241} // end namespace internal 242 243} // end namespace Eigen 244 245#endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 246