1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2009 Guillaume Saupin <guillaume.saupin@cea.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SKYLINEPRODUCT_H
11#define EIGEN_SKYLINEPRODUCT_H
12
13namespace Eigen {
14
15template<typename Lhs, typename Rhs, int ProductMode>
16struct SkylineProductReturnType {
17    typedef const typename internal::nested<Lhs, Rhs::RowsAtCompileTime>::type LhsNested;
18    typedef const typename internal::nested<Rhs, Lhs::RowsAtCompileTime>::type RhsNested;
19
20    typedef SkylineProduct<LhsNested, RhsNested, ProductMode> Type;
21};
22
23template<typename LhsNested, typename RhsNested, int ProductMode>
24struct internal::traits<SkylineProduct<LhsNested, RhsNested, ProductMode> > {
25    // clean the nested types:
26    typedef typename internal::remove_all<LhsNested>::type _LhsNested;
27    typedef typename internal::remove_all<RhsNested>::type _RhsNested;
28    typedef typename _LhsNested::Scalar Scalar;
29
30    enum {
31        LhsCoeffReadCost = _LhsNested::CoeffReadCost,
32        RhsCoeffReadCost = _RhsNested::CoeffReadCost,
33        LhsFlags = _LhsNested::Flags,
34        RhsFlags = _RhsNested::Flags,
35
36        RowsAtCompileTime = _LhsNested::RowsAtCompileTime,
37        ColsAtCompileTime = _RhsNested::ColsAtCompileTime,
38        InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime),
39
40        MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
41        MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
42
43        EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
44        ResultIsSkyline = ProductMode == SkylineTimeSkylineProduct,
45
46        RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSkyline ? 0 : SkylineBit)),
47
48        Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
49        | EvalBeforeAssigningBit
50        | EvalBeforeNestingBit,
51
52        CoeffReadCost = Dynamic
53    };
54
55    typedef typename internal::conditional<ResultIsSkyline,
56            SkylineMatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> >,
57            MatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> > >::type Base;
58};
59
60namespace internal {
61template<typename LhsNested, typename RhsNested, int ProductMode>
62class SkylineProduct : no_assignment_operator,
63public traits<SkylineProduct<LhsNested, RhsNested, ProductMode> >::Base {
64public:
65
66    EIGEN_GENERIC_PUBLIC_INTERFACE(SkylineProduct)
67
68private:
69
70    typedef typename traits<SkylineProduct>::_LhsNested _LhsNested;
71    typedef typename traits<SkylineProduct>::_RhsNested _RhsNested;
72
73public:
74
75    template<typename Lhs, typename Rhs>
76    EIGEN_STRONG_INLINE SkylineProduct(const Lhs& lhs, const Rhs& rhs)
77    : m_lhs(lhs), m_rhs(rhs) {
78        eigen_assert(lhs.cols() == rhs.rows());
79
80        enum {
81            ProductIsValid = _LhsNested::ColsAtCompileTime == Dynamic
82            || _RhsNested::RowsAtCompileTime == Dynamic
83            || int(_LhsNested::ColsAtCompileTime) == int(_RhsNested::RowsAtCompileTime),
84            AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
85            SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested, _RhsNested)
86        };
87        // note to the lost user:
88        //    * for a dot product use: v1.dot(v2)
89        //    * for a coeff-wise product use: v1.cwise()*v2
90        EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
91                INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
92                EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
93                INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
94                EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
95    }
96
97    EIGEN_STRONG_INLINE Index rows() const {
98        return m_lhs.rows();
99    }
100
101    EIGEN_STRONG_INLINE Index cols() const {
102        return m_rhs.cols();
103    }
104
105    EIGEN_STRONG_INLINE const _LhsNested& lhs() const {
106        return m_lhs;
107    }
108
109    EIGEN_STRONG_INLINE const _RhsNested& rhs() const {
110        return m_rhs;
111    }
112
113protected:
114    LhsNested m_lhs;
115    RhsNested m_rhs;
116};
117
118// dense = skyline * dense
119// Note that here we force no inlining and separate the setZero() because GCC messes up otherwise
120
121template<typename Lhs, typename Rhs, typename Dest>
122EIGEN_DONT_INLINE void skyline_row_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
123    typedef typename remove_all<Lhs>::type _Lhs;
124    typedef typename remove_all<Rhs>::type _Rhs;
125    typedef typename traits<Lhs>::Scalar Scalar;
126
127    enum {
128        LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
129        LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
130        ProcessFirstHalf = LhsIsSelfAdjoint
131        && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
132        || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
133        || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
134        ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
135    };
136
137    //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
138    for (Index col = 0; col < rhs.cols(); col++) {
139        for (Index row = 0; row < lhs.rows(); row++) {
140            dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
141        }
142    }
143    //Use matrix lower triangular part
144    for (Index row = 0; row < lhs.rows(); row++) {
145        typename _Lhs::InnerLowerIterator lIt(lhs, row);
146        const Index stop = lIt.col() + lIt.size();
147        for (Index col = 0; col < rhs.cols(); col++) {
148
149            Index k = lIt.col();
150            Scalar tmp = 0;
151            while (k < stop) {
152                tmp +=
153                        lIt.value() *
154                        rhs(k++, col);
155                ++lIt;
156            }
157            dst(row, col) += tmp;
158            lIt += -lIt.size();
159        }
160
161    }
162
163    //Use matrix upper triangular part
164    for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
165        typename _Lhs::InnerUpperIterator uIt(lhs, lhscol);
166        const Index stop = uIt.size() + uIt.row();
167        for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
168
169
170            const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
171            Index k = uIt.row();
172            while (k < stop) {
173                dst(k++, rhscol) +=
174                        uIt.value() *
175                        rhsCoeff;
176                ++uIt;
177            }
178            uIt += -uIt.size();
179        }
180    }
181
182}
183
184template<typename Lhs, typename Rhs, typename Dest>
185EIGEN_DONT_INLINE void skyline_col_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
186    typedef typename remove_all<Lhs>::type _Lhs;
187    typedef typename remove_all<Rhs>::type _Rhs;
188    typedef typename traits<Lhs>::Scalar Scalar;
189
190    enum {
191        LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
192        LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
193        ProcessFirstHalf = LhsIsSelfAdjoint
194        && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
195        || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
196        || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
197        ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
198    };
199
200    //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
201    for (Index col = 0; col < rhs.cols(); col++) {
202        for (Index row = 0; row < lhs.rows(); row++) {
203            dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
204        }
205    }
206
207    //Use matrix upper triangular part
208    for (Index row = 0; row < lhs.rows(); row++) {
209        typename _Lhs::InnerUpperIterator uIt(lhs, row);
210        const Index stop = uIt.col() + uIt.size();
211        for (Index col = 0; col < rhs.cols(); col++) {
212
213            Index k = uIt.col();
214            Scalar tmp = 0;
215            while (k < stop) {
216                tmp +=
217                        uIt.value() *
218                        rhs(k++, col);
219                ++uIt;
220            }
221
222
223            dst(row, col) += tmp;
224            uIt += -uIt.size();
225        }
226    }
227
228    //Use matrix lower triangular part
229    for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
230        typename _Lhs::InnerLowerIterator lIt(lhs, lhscol);
231        const Index stop = lIt.size() + lIt.row();
232        for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
233
234            const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
235            Index k = lIt.row();
236            while (k < stop) {
237                dst(k++, rhscol) +=
238                        lIt.value() *
239                        rhsCoeff;
240                ++lIt;
241            }
242            lIt += -lIt.size();
243        }
244    }
245
246}
247
248template<typename Lhs, typename Rhs, typename ResultType,
249        int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit>
250        struct skyline_product_selector;
251
252template<typename Lhs, typename Rhs, typename ResultType>
253struct skyline_product_selector<Lhs, Rhs, ResultType, RowMajor> {
254    typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
255
256    static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
257        skyline_row_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
258    }
259};
260
261template<typename Lhs, typename Rhs, typename ResultType>
262struct skyline_product_selector<Lhs, Rhs, ResultType, ColMajor> {
263    typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
264
265    static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
266        skyline_col_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
267    }
268};
269
270} // end namespace internal
271
272// template<typename Derived>
273// template<typename Lhs, typename Rhs >
274// Derived & MatrixBase<Derived>::lazyAssign(const SkylineProduct<Lhs, Rhs, SkylineTimeDenseProduct>& product) {
275//     typedef typename internal::remove_all<Lhs>::type _Lhs;
276//     internal::skyline_product_selector<typename internal::remove_all<Lhs>::type,
277//             typename internal::remove_all<Rhs>::type,
278//             Derived>::run(product.lhs(), product.rhs(), derived());
279//
280//     return derived();
281// }
282
283// skyline * dense
284
285template<typename Derived>
286template<typename OtherDerived >
287EIGEN_STRONG_INLINE const typename SkylineProductReturnType<Derived, OtherDerived>::Type
288SkylineMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const {
289
290    return typename SkylineProductReturnType<Derived, OtherDerived>::Type(derived(), other.derived());
291}
292
293} // end namespace Eigen
294
295#endif // EIGEN_SKYLINEPRODUCT_H
296