1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11#define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12
13namespace Eigen {
14
15/** \class TensorExpr
16  * \ingroup CXX11_Tensor_Module
17  *
18  * \brief Tensor expression classes.
19  *
  * The TensorCwiseNullaryOp class applies a nullary operator to an expression.
21  * This is typically used to generate constants.
22  *
23  * The TensorCwiseUnaryOp class represents an expression where a unary operator
24  * (e.g. cwiseSqrt) is applied to an expression.
25  *
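  * The TensorCwiseBinaryOp class represents an expression where a binary
  * operator (e.g. addition) is applied to a left-hand side (lhs) and a
  * right-hand side (rhs) expression.
  *
  * The TensorCwiseTernaryOp class represents an expression where a ternary
  * operator is applied to three argument expressions.
  *
  * The TensorSelectOp class represents an expression that chooses between a
  * "then" expression and an "else" expression according to a condition
  * ("if") expression.
  *
  */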
30namespace internal {
31template<typename NullaryOp, typename XprType>
32struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33    : traits<XprType>
34{
35  typedef traits<XprType> XprTraits;
36  typedef typename XprType::Scalar Scalar;
37  typedef typename XprType::Nested XprTypeNested;
38  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
39  static const int NumDimensions = XprTraits::NumDimensions;
40  static const int Layout = XprTraits::Layout;
41
42  enum {
43    Flags = 0
44  };
45};
46
47}  // end namespace internal
48
49
50
51template<typename NullaryOp, typename XprType>
52class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
53{
54  public:
55    typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
56    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57    typedef typename XprType::CoeffReturnType CoeffReturnType;
58    typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
59    typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
60    typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
61
62    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
63        : m_xpr(xpr), m_functor(func) {}
64
65    EIGEN_DEVICE_FUNC
66    const typename internal::remove_all<typename XprType::Nested>::type&
67    nestedExpression() const { return m_xpr; }
68
69    EIGEN_DEVICE_FUNC
70    const NullaryOp& functor() const { return m_functor; }
71
72  protected:
73    typename XprType::Nested m_xpr;
74    const NullaryOp m_functor;
75};
76
77
78
79namespace internal {
80template<typename UnaryOp, typename XprType>
81struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
82    : traits<XprType>
83{
84  // TODO(phli): Add InputScalar, InputPacket.  Check references to
85  // current Scalar/Packet to see if the intent is Input or Output.
86  typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
87  typedef traits<XprType> XprTraits;
88  typedef typename XprType::Nested XprTypeNested;
89  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
90  static const int NumDimensions = XprTraits::NumDimensions;
91  static const int Layout = XprTraits::Layout;
92};
93
94template<typename UnaryOp, typename XprType>
95struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
96{
97  typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
98};
99
100template<typename UnaryOp, typename XprType>
101struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
102{
103  typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
104};
105
106}  // end namespace internal
107
108
109
110template<typename UnaryOp, typename XprType>
111class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
112{
113  public:
114    // TODO(phli): Add InputScalar, InputPacket.  Check references to
115    // current Scalar/Packet to see if the intent is Input or Output.
116    typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
117    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
118    typedef Scalar CoeffReturnType;
119    typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
120    typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
121    typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
122
123    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
124      : m_xpr(xpr), m_functor(func) {}
125
126    EIGEN_DEVICE_FUNC
127    const UnaryOp& functor() const { return m_functor; }
128
129    /** \returns the nested expression */
130    EIGEN_DEVICE_FUNC
131    const typename internal::remove_all<typename XprType::Nested>::type&
132    nestedExpression() const { return m_xpr; }
133
134  protected:
135    typename XprType::Nested m_xpr;
136    const UnaryOp m_functor;
137};
138
139
140namespace internal {
141template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
142struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
143{
144  // Type promotion to handle the case where the types of the lhs and the rhs
145  // are different.
146  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
147  // current Scalar/Packet to see if the intent is Inputs or Output.
148  typedef typename result_of<
149      BinaryOp(typename LhsXprType::Scalar,
150               typename RhsXprType::Scalar)>::type Scalar;
151  typedef traits<LhsXprType> XprTraits;
152  typedef typename promote_storage_type<
153      typename traits<LhsXprType>::StorageKind,
154      typename traits<RhsXprType>::StorageKind>::ret StorageKind;
155  typedef typename promote_index_type<
156      typename traits<LhsXprType>::Index,
157      typename traits<RhsXprType>::Index>::type Index;
158  typedef typename LhsXprType::Nested LhsNested;
159  typedef typename RhsXprType::Nested RhsNested;
160  typedef typename remove_reference<LhsNested>::type _LhsNested;
161  typedef typename remove_reference<RhsNested>::type _RhsNested;
162  static const int NumDimensions = XprTraits::NumDimensions;
163  static const int Layout = XprTraits::Layout;
164
165  enum {
166    Flags = 0
167  };
168};
169
170template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
171struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
172{
173  typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
174};
175
176template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
177struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
178{
179  typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
180};
181
182}  // end namespace internal
183
184
185
186template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
187class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
188{
189  public:
190    // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
191    // current Scalar/Packet to see if the intent is Inputs or Output.
192    typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
193    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
194    typedef Scalar CoeffReturnType;
195    typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
196    typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
197    typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
198
199    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
200        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
201
202    EIGEN_DEVICE_FUNC
203    const BinaryOp& functor() const { return m_functor; }
204
205    /** \returns the nested expressions */
206    EIGEN_DEVICE_FUNC
207    const typename internal::remove_all<typename LhsXprType::Nested>::type&
208    lhsExpression() const { return m_lhs_xpr; }
209
210    EIGEN_DEVICE_FUNC
211    const typename internal::remove_all<typename RhsXprType::Nested>::type&
212    rhsExpression() const { return m_rhs_xpr; }
213
214  protected:
215    typename LhsXprType::Nested m_lhs_xpr;
216    typename RhsXprType::Nested m_rhs_xpr;
217    const BinaryOp m_functor;
218};
219
220
221namespace internal {
222template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
223struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
224{
225  // Type promotion to handle the case where the types of the args are different.
226  typedef typename result_of<
227      TernaryOp(typename Arg1XprType::Scalar,
228                typename Arg2XprType::Scalar,
229                typename Arg3XprType::Scalar)>::type Scalar;
230  typedef traits<Arg1XprType> XprTraits;
231  typedef typename traits<Arg1XprType>::StorageKind StorageKind;
232  typedef typename traits<Arg1XprType>::Index Index;
233  typedef typename Arg1XprType::Nested Arg1Nested;
234  typedef typename Arg2XprType::Nested Arg2Nested;
235  typedef typename Arg3XprType::Nested Arg3Nested;
236  typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
237  typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
238  typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
239  static const int NumDimensions = XprTraits::NumDimensions;
240  static const int Layout = XprTraits::Layout;
241
242  enum {
243    Flags = 0
244  };
245};
246
247template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
248struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
249{
250  typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
251};
252
253template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
254struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
255{
256  typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
257};
258
259}  // end namespace internal
260
261
262
263template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
264class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
265{
266  public:
267    typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
268    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
269    typedef Scalar CoeffReturnType;
270    typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
271    typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
272    typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
273
274    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
275        : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
276
277    EIGEN_DEVICE_FUNC
278    const TernaryOp& functor() const { return m_functor; }
279
280    /** \returns the nested expressions */
281    EIGEN_DEVICE_FUNC
282    const typename internal::remove_all<typename Arg1XprType::Nested>::type&
283    arg1Expression() const { return m_arg1_xpr; }
284
285    EIGEN_DEVICE_FUNC
286    const typename internal::remove_all<typename Arg2XprType::Nested>::type&
287    arg2Expression() const { return m_arg2_xpr; }
288
289    EIGEN_DEVICE_FUNC
290    const typename internal::remove_all<typename Arg3XprType::Nested>::type&
291    arg3Expression() const { return m_arg3_xpr; }
292
293  protected:
294    typename Arg1XprType::Nested m_arg1_xpr;
295    typename Arg2XprType::Nested m_arg2_xpr;
296    typename Arg3XprType::Nested m_arg3_xpr;
297    const TernaryOp m_functor;
298};
299
300
301namespace internal {
302template<typename IfXprType, typename ThenXprType, typename ElseXprType>
303struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
304    : traits<ThenXprType>
305{
306  typedef typename traits<ThenXprType>::Scalar Scalar;
307  typedef traits<ThenXprType> XprTraits;
308  typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
309                                        typename traits<ElseXprType>::StorageKind>::ret StorageKind;
310  typedef typename promote_index_type<typename traits<ElseXprType>::Index,
311                                      typename traits<ThenXprType>::Index>::type Index;
312  typedef typename IfXprType::Nested IfNested;
313  typedef typename ThenXprType::Nested ThenNested;
314  typedef typename ElseXprType::Nested ElseNested;
315  static const int NumDimensions = XprTraits::NumDimensions;
316  static const int Layout = XprTraits::Layout;
317};
318
319template<typename IfXprType, typename ThenXprType, typename ElseXprType>
320struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
321{
322  typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
323};
324
325template<typename IfXprType, typename ThenXprType, typename ElseXprType>
326struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
327{
328  typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
329};
330
331}  // end namespace internal
332
333
334template<typename IfXprType, typename ThenXprType, typename ElseXprType>
335class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
336{
337  public:
338    typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
339    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
340    typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
341                                                    typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
342    typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
343    typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
344    typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
345
346    EIGEN_DEVICE_FUNC
347    TensorSelectOp(const IfXprType& a_condition,
348                   const ThenXprType& a_then,
349                   const ElseXprType& a_else)
350      : m_condition(a_condition), m_then(a_then), m_else(a_else)
351    { }
352
353    EIGEN_DEVICE_FUNC
354    const IfXprType& ifExpression() const { return m_condition; }
355
356    EIGEN_DEVICE_FUNC
357    const ThenXprType& thenExpression() const { return m_then; }
358
359    EIGEN_DEVICE_FUNC
360    const ElseXprType& elseExpression() const { return m_else; }
361
362  protected:
363    typename IfXprType::Nested m_condition;
364    typename ThenXprType::Nested m_then;
365    typename ElseXprType::Nested m_else;
366};
367
368
369} // end namespace Eigen
370
371#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
372