1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12
13namespace Eigen {
14
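/** \class TensorCustomUnaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-supplied unary functor to a sub-expression.
  *
  * The functor provides the output dimensions from the argument expression and
  * is asked to evaluate the whole result into a preallocated output tensor.
  */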
22namespace internal {
23template<typename CustomUnaryFunc, typename XprType>
24struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
25{
26  typedef typename XprType::Scalar Scalar;
27  typedef typename XprType::StorageKind StorageKind;
28  typedef typename XprType::Index Index;
29  typedef typename XprType::Nested Nested;
30  typedef typename remove_reference<Nested>::type _Nested;
31  static const int NumDimensions = traits<XprType>::NumDimensions;
32  static const int Layout = traits<XprType>::Layout;
33};
34
35template<typename CustomUnaryFunc, typename XprType>
36struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
37{
38  typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
39};
40
41template<typename CustomUnaryFunc, typename XprType>
42struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
43{
44  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
45};
46
47}  // end namespace internal
48
49
50
51template<typename CustomUnaryFunc, typename XprType>
52class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
53{
54  public:
55  typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
56  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57  typedef typename XprType::CoeffReturnType CoeffReturnType;
58  typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
59  typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
60  typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
61
62  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
63      : m_expr(expr), m_func(func) {}
64
65  EIGEN_DEVICE_FUNC
66  const CustomUnaryFunc& func() const { return m_func; }
67
68  EIGEN_DEVICE_FUNC
69  const typename internal::remove_all<typename XprType::Nested>::type&
70  expression() const { return m_expr; }
71
72  protected:
73    typename XprType::Nested m_expr;
74    const CustomUnaryFunc m_func;
75};
76
77
78// Eval as rvalue
79template<typename CustomUnaryFunc, typename XprType, typename Device>
80struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
81{
82  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
83  typedef typename internal::traits<ArgType>::Index Index;
84  static const int NumDims = internal::traits<ArgType>::NumDimensions;
85  typedef DSizes<Index, NumDims> Dimensions;
86  typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
87  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
88  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
89  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
90
91  enum {
92    IsAligned = false,
93    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
94    BlockAccess = false,
95    Layout = TensorEvaluator<XprType, Device>::Layout,
96    CoordAccess = false,  // to be implemented
97    RawAccess = false
98  };
99
100  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
101      : m_op(op), m_device(device), m_result(NULL)
102  {
103    m_dimensions = op.func().dimensions(op.expression());
104  }
105
106  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
107
108  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
109    if (data) {
110      evalTo(data);
111      return false;
112    } else {
113      m_result = static_cast<CoeffReturnType*>(
114          m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
115      evalTo(m_result);
116      return true;
117    }
118  }
119
120  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
121    if (m_result != NULL) {
122      m_device.deallocate(m_result);
123      m_result = NULL;
124    }
125  }
126
127  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
128    return m_result[index];
129  }
130
131  template<int LoadMode>
132  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
133    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
134  }
135
136  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
137    // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
138    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
139  }
140
141  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
142
143 protected:
144  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
145    TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
146        data, m_dimensions);
147    m_op.func().eval(m_op.expression(), result, m_device);
148  }
149
150  Dimensions m_dimensions;
151  const ArgType m_op;
152  const Device& m_device;
153  CoeffReturnType* m_result;
154};
155
156
157
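/** \class TensorCustomBinaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-supplied binary functor to two sub-expressions.
  *
  * The functor provides the output dimensions from the lhs and rhs expressions
  * and is asked to evaluate the whole result into a preallocated output tensor.
  */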
165namespace internal {
166template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
167struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
168{
169  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
170                                                  typename RhsXprType::Scalar>::ret Scalar;
171  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
172                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
173  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
174                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
175  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
176                                      typename traits<RhsXprType>::Index>::type Index;
177  typedef typename LhsXprType::Nested LhsNested;
178  typedef typename RhsXprType::Nested RhsNested;
179  typedef typename remove_reference<LhsNested>::type _LhsNested;
180  typedef typename remove_reference<RhsNested>::type _RhsNested;
181  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
182  static const int Layout = traits<LhsXprType>::Layout;
183};
184
185template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
186struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
187{
188  typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
189};
190
191template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
192struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
193{
194  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
195};
196
197}  // end namespace internal
198
199
200
201template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
202class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
203{
204  public:
205  typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
206  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
207  typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
208  typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
209  typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
210  typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
211
212  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
213
214      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
215
216  EIGEN_DEVICE_FUNC
217  const CustomBinaryFunc& func() const { return m_func; }
218
219  EIGEN_DEVICE_FUNC
220  const typename internal::remove_all<typename LhsXprType::Nested>::type&
221  lhsExpression() const { return m_lhs_xpr; }
222
223  EIGEN_DEVICE_FUNC
224  const typename internal::remove_all<typename RhsXprType::Nested>::type&
225  rhsExpression() const { return m_rhs_xpr; }
226
227  protected:
228    typename LhsXprType::Nested m_lhs_xpr;
229    typename RhsXprType::Nested m_rhs_xpr;
230    const CustomBinaryFunc m_func;
231};
232
233
234// Eval as rvalue
235template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
236struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
237{
238  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
239  typedef typename internal::traits<XprType>::Index Index;
240  static const int NumDims = internal::traits<XprType>::NumDimensions;
241  typedef DSizes<Index, NumDims> Dimensions;
242  typedef typename XprType::Scalar Scalar;
243  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
244  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
245  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
246
247  enum {
248    IsAligned = false,
249    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
250    BlockAccess = false,
251    Layout = TensorEvaluator<LhsXprType, Device>::Layout,
252    CoordAccess = false,  // to be implemented
253    RawAccess = false
254  };
255
256  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
257      : m_op(op), m_device(device), m_result(NULL)
258  {
259    m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
260  }
261
262  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
263
264  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
265    if (data) {
266      evalTo(data);
267      return false;
268    } else {
269      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
270      evalTo(m_result);
271      return true;
272    }
273  }
274
275  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
276    if (m_result != NULL) {
277      m_device.deallocate(m_result);
278      m_result = NULL;
279    }
280  }
281
282  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
283    return m_result[index];
284  }
285
286  template<int LoadMode>
287  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
288    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
289  }
290
291  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
292    // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
293    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
294  }
295
296  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
297
298 protected:
299  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
300    TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
301    m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
302  }
303
304  Dimensions m_dimensions;
305  const XprType m_op;
306  const Device& m_device;
307  CoeffReturnType* m_result;
308};
309
310
311} // end namespace Eigen
312
313#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
314