1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 11#define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 12 13namespace Eigen { 14 15/** \class TensorExpr 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor expression classes. 19 * 20 * The TensorCwiseNullaryOp class applies a nullary operators to an expression. 21 * This is typically used to generate constants. 22 * 23 * The TensorCwiseUnaryOp class represents an expression where a unary operator 24 * (e.g. cwiseSqrt) is applied to an expression. 25 * 26 * The TensorCwiseBinaryOp class represents an expression where a binary 27 * operator (e.g. addition) is applied to a lhs and a rhs expression. 28 * 29 */ 30namespace internal { 31template<typename NullaryOp, typename XprType> 32struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > 33 : traits<XprType> 34{ 35 typedef traits<XprType> XprTraits; 36 typedef typename XprType::Scalar Scalar; 37 typedef typename XprType::Nested XprTypeNested; 38 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 39 static const int NumDimensions = XprTraits::NumDimensions; 40 static const int Layout = XprTraits::Layout; 41 42 enum { 43 Flags = 0 44 }; 45}; 46 47} // end namespace internal 48 49 50 51template<typename NullaryOp, typename XprType> 52class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> 53{ 54 public: 55 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; 56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 57 typedef typename XprType::CoeffReturnType CoeffReturnType; 58 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested; 59 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; 60 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; 61 62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) 63 : m_xpr(xpr), m_functor(func) {} 64 65 EIGEN_DEVICE_FUNC 66 const typename internal::remove_all<typename XprType::Nested>::type& 67 nestedExpression() const { return m_xpr; } 68 69 EIGEN_DEVICE_FUNC 70 const NullaryOp& functor() const { return m_functor; } 71 72 protected: 73 typename XprType::Nested m_xpr; 74 const NullaryOp m_functor; 75}; 76 77 78 79namespace internal { 80template<typename UnaryOp, typename XprType> 81struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > 82 : traits<XprType> 83{ 84 // TODO(phli): Add InputScalar, InputPacket. Check references to 85 // current Scalar/Packet to see if the intent is Input or Output. 86 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; 87 typedef traits<XprType> XprTraits; 88 typedef typename XprType::Nested XprTypeNested; 89 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 90 static const int NumDimensions = XprTraits::NumDimensions; 91 static const int Layout = XprTraits::Layout; 92}; 93 94template<typename UnaryOp, typename XprType> 95struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense> 96{ 97 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type; 98}; 99 100template<typename UnaryOp, typename XprType> 101struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type> 102{ 103 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type; 104}; 105 106} // end namespace internal 107 108 109 110template<typename UnaryOp, typename XprType> 111class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> 112{ 113 public: 114 // TODO(phli): Add InputScalar, InputPacket. Check references to 115 // current Scalar/Packet to see if the intent is Input or Output. 116 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; 117 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 118 typedef Scalar CoeffReturnType; 119 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; 120 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; 121 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; 122 123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp()) 124 : m_xpr(xpr), m_functor(func) {} 125 126 EIGEN_DEVICE_FUNC 127 const UnaryOp& functor() const { return m_functor; } 128 129 /** \returns the nested expression */ 130 EIGEN_DEVICE_FUNC 131 const typename internal::remove_all<typename XprType::Nested>::type& 132 nestedExpression() const { return m_xpr; } 133 134 protected: 135 typename XprType::Nested m_xpr; 136 const UnaryOp m_functor; 137}; 138 139 140namespace internal { 141template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 142struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > 143{ 144 // Type promotion to handle the case where the types of the lhs and the rhs 145 // are different. 146 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 147 // current Scalar/Packet to see if the intent is Inputs or Output. 148 typedef typename result_of< 149 BinaryOp(typename LhsXprType::Scalar, 150 typename RhsXprType::Scalar)>::type Scalar; 151 typedef traits<LhsXprType> XprTraits; 152 typedef typename promote_storage_type< 153 typename traits<LhsXprType>::StorageKind, 154 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 155 typedef typename promote_index_type< 156 typename traits<LhsXprType>::Index, 157 typename traits<RhsXprType>::Index>::type Index; 158 typedef typename LhsXprType::Nested LhsNested; 159 typedef typename RhsXprType::Nested RhsNested; 160 typedef typename remove_reference<LhsNested>::type _LhsNested; 161 typedef typename remove_reference<RhsNested>::type _RhsNested; 162 static const int NumDimensions = XprTraits::NumDimensions; 163 static const int Layout = XprTraits::Layout; 164 165 enum { 166 Flags = 0 167 }; 168}; 169 170template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 171struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense> 172{ 173 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type; 174}; 175 176template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 177struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type> 178{ 179 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type; 180}; 181 182} // end namespace internal 183 184 185 186template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 187class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> 188{ 189 public: 190 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 191 // current Scalar/Packet to see if the intent is Inputs or Output. 192 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; 193 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 194 typedef Scalar CoeffReturnType; 195 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; 196 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; 197 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; 198 199 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) 200 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} 201 202 EIGEN_DEVICE_FUNC 203 const BinaryOp& functor() const { return m_functor; } 204 205 /** \returns the nested expressions */ 206 EIGEN_DEVICE_FUNC 207 const typename internal::remove_all<typename LhsXprType::Nested>::type& 208 lhsExpression() const { return m_lhs_xpr; } 209 210 EIGEN_DEVICE_FUNC 211 const typename internal::remove_all<typename RhsXprType::Nested>::type& 212 rhsExpression() const { return m_rhs_xpr; } 213 214 protected: 215 typename LhsXprType::Nested m_lhs_xpr; 216 typename RhsXprType::Nested m_rhs_xpr; 217 const BinaryOp m_functor; 218}; 219 220 221namespace internal { 222template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 223struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > 224{ 225 // Type promotion to handle the case where the types of the args are different. 226 typedef typename result_of< 227 TernaryOp(typename Arg1XprType::Scalar, 228 typename Arg2XprType::Scalar, 229 typename Arg3XprType::Scalar)>::type Scalar; 230 typedef traits<Arg1XprType> XprTraits; 231 typedef typename traits<Arg1XprType>::StorageKind StorageKind; 232 typedef typename traits<Arg1XprType>::Index Index; 233 typedef typename Arg1XprType::Nested Arg1Nested; 234 typedef typename Arg2XprType::Nested Arg2Nested; 235 typedef typename Arg3XprType::Nested Arg3Nested; 236 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; 237 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; 238 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; 239 static const int NumDimensions = XprTraits::NumDimensions; 240 static const int Layout = XprTraits::Layout; 241 242 enum { 243 Flags = 0 244 }; 245}; 246 247template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 248struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> 249{ 250 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type; 251}; 252 253template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 254struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> 255{ 256 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type; 257}; 258 259} // end namespace internal 260 261 262 263template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 264class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> 265{ 266 public: 267 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar; 268 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 269 typedef Scalar CoeffReturnType; 270 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested; 271 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind; 272 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index; 273 274 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) 275 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} 276 277 EIGEN_DEVICE_FUNC 278 const TernaryOp& functor() const { return m_functor; } 279 280 /** \returns the nested expressions */ 281 EIGEN_DEVICE_FUNC 282 const typename internal::remove_all<typename Arg1XprType::Nested>::type& 283 arg1Expression() const { return m_arg1_xpr; } 284 285 EIGEN_DEVICE_FUNC 286 const typename internal::remove_all<typename Arg2XprType::Nested>::type& 287 arg2Expression() const { return m_arg2_xpr; } 288 289 EIGEN_DEVICE_FUNC 290 const typename internal::remove_all<typename Arg3XprType::Nested>::type& 291 arg3Expression() const { return m_arg3_xpr; } 292 293 protected: 294 typename Arg1XprType::Nested m_arg1_xpr; 295 typename Arg2XprType::Nested m_arg2_xpr; 296 typename Arg3XprType::Nested m_arg3_xpr; 297 const TernaryOp m_functor; 298}; 299 300 301namespace internal { 302template<typename IfXprType, typename ThenXprType, typename ElseXprType> 303struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > 304 : traits<ThenXprType> 305{ 306 typedef typename traits<ThenXprType>::Scalar Scalar; 307 typedef traits<ThenXprType> XprTraits; 308 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, 309 typename traits<ElseXprType>::StorageKind>::ret StorageKind; 310 typedef typename promote_index_type<typename traits<ElseXprType>::Index, 311 typename traits<ThenXprType>::Index>::type Index; 312 typedef typename IfXprType::Nested IfNested; 313 typedef typename ThenXprType::Nested ThenNested; 314 typedef typename ElseXprType::Nested ElseNested; 315 static const int NumDimensions = XprTraits::NumDimensions; 316 static const int Layout = XprTraits::Layout; 317}; 318 319template<typename IfXprType, typename ThenXprType, typename ElseXprType> 320struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> 321{ 322 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; 323}; 324 325template<typename IfXprType, typename ThenXprType, typename ElseXprType> 326struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> 327{ 328 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; 329}; 330 331} // end namespace internal 332 333 334template<typename IfXprType, typename ThenXprType, typename ElseXprType> 335class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> 336{ 337 public: 338 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; 339 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 340 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, 341 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; 342 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; 343 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; 344 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; 345 346 EIGEN_DEVICE_FUNC 347 TensorSelectOp(const IfXprType& a_condition, 348 const ThenXprType& a_then, 349 const ElseXprType& a_else) 350 : m_condition(a_condition), m_then(a_then), m_else(a_else) 351 { } 352 353 EIGEN_DEVICE_FUNC 354 const IfXprType& ifExpression() const { return m_condition; } 355 356 EIGEN_DEVICE_FUNC 357 const ThenXprType& thenExpression() const { return m_then; } 358 359 EIGEN_DEVICE_FUNC 360 const ElseXprType& elseExpression() const { return m_else; } 361 362 protected: 363 typename IfXprType::Nested m_condition; 364 typename ThenXprType::Nested m_then; 365 typename ElseXprType::Nested m_else; 366}; 367 368 369} // end namespace Eigen 370 371#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 372