// nullary_indexing.cpp revision 2b8756b6f1de65d3f8bffab45be6c44ceb7411fc
1#include <Eigen/Core> 2#include <iostream> 3 4using namespace Eigen; 5 6// [functor] 7template<class ArgType, class RowIndexType, class ColIndexType> 8class indexing_functor { 9 const ArgType &m_arg; 10 const RowIndexType &m_rowIndices; 11 const ColIndexType &m_colIndices; 12public: 13 typedef Matrix<typename ArgType::Scalar, 14 RowIndexType::SizeAtCompileTime, 15 ColIndexType::SizeAtCompileTime, 16 ArgType::Flags&RowMajorBit?RowMajor:ColMajor, 17 RowIndexType::MaxSizeAtCompileTime, 18 ColIndexType::MaxSizeAtCompileTime> MatrixType; 19 20 indexing_functor(const ArgType& arg, const RowIndexType& row_indices, const ColIndexType& col_indices) 21 : m_arg(arg), m_rowIndices(row_indices), m_colIndices(col_indices) 22 {} 23 24 const typename ArgType::Scalar& operator() (Index row, Index col) const { 25 return m_arg(m_rowIndices[row], m_colIndices[col]); 26 } 27}; 28// [functor] 29 30// [function] 31template <class ArgType, class RowIndexType, class ColIndexType> 32CwiseNullaryOp<indexing_functor<ArgType,RowIndexType,ColIndexType>, typename indexing_functor<ArgType,RowIndexType,ColIndexType>::MatrixType> 33indexing(const Eigen::MatrixBase<ArgType>& arg, const RowIndexType& row_indices, const ColIndexType& col_indices) 34{ 35 typedef indexing_functor<ArgType,RowIndexType,ColIndexType> Func; 36 typedef typename Func::MatrixType MatrixType; 37 return MatrixType::NullaryExpr(row_indices.size(), col_indices.size(), Func(arg.derived(), row_indices, col_indices)); 38} 39// [function] 40 41 42int main() 43{ 44 std::cout << "[main1]\n"; 45 Eigen::MatrixXi A = Eigen::MatrixXi::Random(4,4); 46 Array3i ri(1,2,1); 47 ArrayXi ci(6); ci << 3,2,1,0,0,2; 48 Eigen::MatrixXi B = indexing(A, ri, ci); 49 std::cout << "A =" << std::endl; 50 std::cout << A << std::endl << std::endl; 51 std::cout << "A([" << ri.transpose() << "], [" << ci.transpose() << "]) =" << std::endl; 52 std::cout << B << std::endl; 53 std::cout << "[main1]\n"; 54 55 std::cout << "[main2]\n"; 56 B = indexing(A, ri+1, ci); 
57 std::cout << "A(ri+1,ci) =" << std::endl; 58 std::cout << B << std::endl << std::endl; 59#if __cplusplus >= 201103L 60 B = indexing(A, ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)); 61 std::cout << "A(ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)) =" << std::endl; 62 std::cout << B << std::endl << std::endl; 63#endif 64 std::cout << "[main2]\n"; 65} 66 67