1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 12 13namespace Eigen { 14 15/** \class TensorCustomUnaryOp 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor custom class. 19 * 20 * 21 */ 22namespace internal { 23template<typename CustomUnaryFunc, typename XprType> 24struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 25{ 26 typedef typename XprType::Scalar Scalar; 27 typedef typename XprType::StorageKind StorageKind; 28 typedef typename XprType::Index Index; 29 typedef typename XprType::Nested Nested; 30 typedef typename remove_reference<Nested>::type _Nested; 31 static const int NumDimensions = traits<XprType>::NumDimensions; 32 static const int Layout = traits<XprType>::Layout; 33}; 34 35template<typename CustomUnaryFunc, typename XprType> 36struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> 37{ 38 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type; 39}; 40 41template<typename CustomUnaryFunc, typename XprType> 42struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 43{ 44 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type; 45}; 46 47} // end namespace internal 48 49 50 51template<typename CustomUnaryFunc, typename XprType> 52class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> 53{ 54 public: 55 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar; 56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 57 typedef typename XprType::CoeffReturnType CoeffReturnType; 58 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested; 59 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind; 60 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index; 61 62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func) 63 : m_expr(expr), m_func(func) {} 64 65 EIGEN_DEVICE_FUNC 66 const CustomUnaryFunc& func() const { return m_func; } 67 68 EIGEN_DEVICE_FUNC 69 const typename internal::remove_all<typename XprType::Nested>::type& 70 expression() const { return m_expr; } 71 72 protected: 73 typename XprType::Nested m_expr; 74 const CustomUnaryFunc m_func; 75}; 76 77 78// Eval as rvalue 79template<typename CustomUnaryFunc, typename XprType, typename Device> 80struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> 81{ 82 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType; 83 typedef typename internal::traits<ArgType>::Index Index; 84 static const int NumDims = internal::traits<ArgType>::NumDimensions; 85 typedef DSizes<Index, NumDims> Dimensions; 86 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar; 87 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 88 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 89 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 90 91 enum { 92 IsAligned = false, 93 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 94 BlockAccess = false, 95 Layout = TensorEvaluator<XprType, Device>::Layout, 96 CoordAccess = false, // to be implemented 97 RawAccess = false 98 }; 99 100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device) 101 : m_op(op), m_device(device), m_result(NULL) 102 { 103 m_dimensions = op.func().dimensions(op.expression()); 104 } 105 106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 107 108 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 109 if (data) { 110 evalTo(data); 111 return false; 112 } else { 113 m_result = static_cast<CoeffReturnType*>( 114 m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 115 evalTo(m_result); 116 return true; 117 } 118 } 119 120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 121 if (m_result != NULL) { 122 m_device.deallocate(m_result); 123 m_result = NULL; 124 } 125 } 126 127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 128 return m_result[index]; 129 } 130 131 template<int LoadMode> 132 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 133 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 134 } 135 136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 137 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 138 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 139 } 140 141 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 142 143 protected: 144 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 145 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result( 146 data, m_dimensions); 147 m_op.func().eval(m_op.expression(), result, m_device); 148 } 149 150 Dimensions m_dimensions; 151 const ArgType m_op; 152 const Device& m_device; 153 CoeffReturnType* m_result; 154}; 155 156 157 158/** \class TensorCustomBinaryOp 159 * \ingroup CXX11_Tensor_Module 160 * 161 * \brief Tensor custom class. 162 * 163 * 164 */ 165namespace internal { 166template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 167struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 168{ 169 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, 170 typename RhsXprType::Scalar>::ret Scalar; 171 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, 172 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; 173 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, 174 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 175 typedef typename promote_index_type<typename traits<LhsXprType>::Index, 176 typename traits<RhsXprType>::Index>::type Index; 177 typedef typename LhsXprType::Nested LhsNested; 178 typedef typename RhsXprType::Nested RhsNested; 179 typedef typename remove_reference<LhsNested>::type _LhsNested; 180 typedef typename remove_reference<RhsNested>::type _RhsNested; 181 static const int NumDimensions = traits<LhsXprType>::NumDimensions; 182 static const int Layout = traits<LhsXprType>::Layout; 183}; 184 185template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 186struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> 187{ 188 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type; 189}; 190 191template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 192struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 193{ 194 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type; 195}; 196 197} // end namespace internal 198 199 200 201template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 202class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> 203{ 204 public: 205 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar; 206 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 207 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType; 208 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested; 209 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind; 210 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index; 211 212 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func) 213 214 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {} 215 216 EIGEN_DEVICE_FUNC 217 const CustomBinaryFunc& func() const { return m_func; } 218 219 EIGEN_DEVICE_FUNC 220 const typename internal::remove_all<typename LhsXprType::Nested>::type& 221 lhsExpression() const { return m_lhs_xpr; } 222 223 EIGEN_DEVICE_FUNC 224 const typename internal::remove_all<typename RhsXprType::Nested>::type& 225 rhsExpression() const { return m_rhs_xpr; } 226 227 protected: 228 typename LhsXprType::Nested m_lhs_xpr; 229 typename RhsXprType::Nested m_rhs_xpr; 230 const CustomBinaryFunc m_func; 231}; 232 233 234// Eval as rvalue 235template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device> 236struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> 237{ 238 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType; 239 typedef typename internal::traits<XprType>::Index Index; 240 static const int NumDims = internal::traits<XprType>::NumDimensions; 241 typedef DSizes<Index, NumDims> Dimensions; 242 typedef typename XprType::Scalar Scalar; 243 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 244 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 245 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 246 247 enum { 248 IsAligned = false, 249 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 250 BlockAccess = false, 251 Layout = TensorEvaluator<LhsXprType, Device>::Layout, 252 CoordAccess = false, // to be implemented 253 RawAccess = false 254 }; 255 256 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) 257 : m_op(op), m_device(device), m_result(NULL) 258 { 259 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression()); 260 } 261 262 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 263 264 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 265 if (data) { 266 evalTo(data); 267 return false; 268 } else { 269 m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 270 evalTo(m_result); 271 return true; 272 } 273 } 274 275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 276 if (m_result != NULL) { 277 m_device.deallocate(m_result); 278 m_result = NULL; 279 } 280 } 281 282 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 283 return m_result[index]; 284 } 285 286 template<int LoadMode> 287 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 288 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 289 } 290 291 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 292 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 293 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 294 } 295 296 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 297 298 protected: 299 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 300 TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions); 301 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); 302 } 303 304 Dimensions m_dimensions; 305 const XprType m_op; 306 const Device& m_device; 307 CoeffReturnType* m_result; 308}; 309 310 311} // end namespace Eigen 312 313#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 314