1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_ARRAYWRAPPER_H
11#define EIGEN_ARRAYWRAPPER_H
12
13namespace Eigen {
14
15/** \class ArrayWrapper
16  * \ingroup Core_Module
17  *
18  * \brief Expression of a mathematical vector or matrix as an array object
19  *
  * This class is the return type of MatrixBase::array(), and most of the time
  * this is the only way it is used.
22  *
23  * \sa MatrixBase::array(), class MatrixWrapper
24  */
25
26namespace internal {
27template<typename ExpressionType>
28struct traits<ArrayWrapper<ExpressionType> >
29  : public traits<typename remove_all<typename ExpressionType::Nested>::type >
30{
31  typedef ArrayXpr XprKind;
32  // Let's remove NestByRefBit
33  enum {
34    Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
35    Flags = Flags0 & ~NestByRefBit
36  };
37};
38}
39
40template<typename ExpressionType>
41class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
42{
43  public:
44    typedef ArrayBase<ArrayWrapper> Base;
45    EIGEN_DENSE_PUBLIC_INTERFACE(ArrayWrapper)
46    EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ArrayWrapper)
47
48    typedef typename internal::conditional<
49                       internal::is_lvalue<ExpressionType>::value,
50                       Scalar,
51                       const Scalar
52                     >::type ScalarWithConstIfNotLvalue;
53
54    typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
55
56    inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {}
57
58    inline Index rows() const { return m_expression.rows(); }
59    inline Index cols() const { return m_expression.cols(); }
60    inline Index outerStride() const { return m_expression.outerStride(); }
61    inline Index innerStride() const { return m_expression.innerStride(); }
62
63    inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
64    inline const Scalar* data() const { return m_expression.data(); }
65
66    inline CoeffReturnType coeff(Index rowId, Index colId) const
67    {
68      return m_expression.coeff(rowId, colId);
69    }
70
71    inline Scalar& coeffRef(Index rowId, Index colId)
72    {
73      return m_expression.const_cast_derived().coeffRef(rowId, colId);
74    }
75
76    inline const Scalar& coeffRef(Index rowId, Index colId) const
77    {
78      return m_expression.const_cast_derived().coeffRef(rowId, colId);
79    }
80
81    inline CoeffReturnType coeff(Index index) const
82    {
83      return m_expression.coeff(index);
84    }
85
86    inline Scalar& coeffRef(Index index)
87    {
88      return m_expression.const_cast_derived().coeffRef(index);
89    }
90
91    inline const Scalar& coeffRef(Index index) const
92    {
93      return m_expression.const_cast_derived().coeffRef(index);
94    }
95
96    template<int LoadMode>
97    inline const PacketScalar packet(Index rowId, Index colId) const
98    {
99      return m_expression.template packet<LoadMode>(rowId, colId);
100    }
101
102    template<int LoadMode>
103    inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
104    {
105      m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val);
106    }
107
108    template<int LoadMode>
109    inline const PacketScalar packet(Index index) const
110    {
111      return m_expression.template packet<LoadMode>(index);
112    }
113
114    template<int LoadMode>
115    inline void writePacket(Index index, const PacketScalar& val)
116    {
117      m_expression.const_cast_derived().template writePacket<LoadMode>(index, val);
118    }
119
120    template<typename Dest>
121    inline void evalTo(Dest& dst) const { dst = m_expression; }
122
123    const typename internal::remove_all<NestedExpressionType>::type&
124    nestedExpression() const
125    {
126      return m_expression;
127    }
128
129    /** Forwards the resizing request to the nested expression
130      * \sa DenseBase::resize(Index)  */
131    void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
132    /** Forwards the resizing request to the nested expression
133      * \sa DenseBase::resize(Index,Index)*/
134    void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
135
136  protected:
137    NestedExpressionType m_expression;
138};
139
140/** \class MatrixWrapper
141  * \ingroup Core_Module
142  *
143  * \brief Expression of an array as a mathematical vector or matrix
144  *
  * This class is the return type of ArrayBase::matrix(), and most of the time
  * this is the only way it is used.
147  *
  * \sa ArrayBase::matrix(), class ArrayWrapper
149  */
150
151namespace internal {
152template<typename ExpressionType>
153struct traits<MatrixWrapper<ExpressionType> >
154 : public traits<typename remove_all<typename ExpressionType::Nested>::type >
155{
156  typedef MatrixXpr XprKind;
157  // Let's remove NestByRefBit
158  enum {
159    Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
160    Flags = Flags0 & ~NestByRefBit
161  };
162};
163}
164
165template<typename ExpressionType>
166class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
167{
168  public:
169    typedef MatrixBase<MatrixWrapper<ExpressionType> > Base;
170    EIGEN_DENSE_PUBLIC_INTERFACE(MatrixWrapper)
171    EIGEN_INHERIT_ASSIGNMENT_OPERATORS(MatrixWrapper)
172
173    typedef typename internal::conditional<
174                       internal::is_lvalue<ExpressionType>::value,
175                       Scalar,
176                       const Scalar
177                     >::type ScalarWithConstIfNotLvalue;
178
179    typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
180
181    inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {}
182
183    inline Index rows() const { return m_expression.rows(); }
184    inline Index cols() const { return m_expression.cols(); }
185    inline Index outerStride() const { return m_expression.outerStride(); }
186    inline Index innerStride() const { return m_expression.innerStride(); }
187
188    inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
189    inline const Scalar* data() const { return m_expression.data(); }
190
191    inline CoeffReturnType coeff(Index rowId, Index colId) const
192    {
193      return m_expression.coeff(rowId, colId);
194    }
195
196    inline Scalar& coeffRef(Index rowId, Index colId)
197    {
198      return m_expression.const_cast_derived().coeffRef(rowId, colId);
199    }
200
201    inline const Scalar& coeffRef(Index rowId, Index colId) const
202    {
203      return m_expression.derived().coeffRef(rowId, colId);
204    }
205
206    inline CoeffReturnType coeff(Index index) const
207    {
208      return m_expression.coeff(index);
209    }
210
211    inline Scalar& coeffRef(Index index)
212    {
213      return m_expression.const_cast_derived().coeffRef(index);
214    }
215
216    inline const Scalar& coeffRef(Index index) const
217    {
218      return m_expression.const_cast_derived().coeffRef(index);
219    }
220
221    template<int LoadMode>
222    inline const PacketScalar packet(Index rowId, Index colId) const
223    {
224      return m_expression.template packet<LoadMode>(rowId, colId);
225    }
226
227    template<int LoadMode>
228    inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
229    {
230      m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val);
231    }
232
233    template<int LoadMode>
234    inline const PacketScalar packet(Index index) const
235    {
236      return m_expression.template packet<LoadMode>(index);
237    }
238
239    template<int LoadMode>
240    inline void writePacket(Index index, const PacketScalar& val)
241    {
242      m_expression.const_cast_derived().template writePacket<LoadMode>(index, val);
243    }
244
245    const typename internal::remove_all<NestedExpressionType>::type&
246    nestedExpression() const
247    {
248      return m_expression;
249    }
250
251    /** Forwards the resizing request to the nested expression
252      * \sa DenseBase::resize(Index)  */
253    void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
254    /** Forwards the resizing request to the nested expression
255      * \sa DenseBase::resize(Index,Index)*/
256    void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
257
258  protected:
259    NestedExpressionType m_expression;
260};
261
262} // end namespace Eigen
263
264#endif // EIGEN_ARRAYWRAPPER_H
265