1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Mehdi Goli    Codeplay Software Ltd.
5// Ralph Potter  Codeplay Software Ltd.
6// Luke Iwanski  Codeplay Software Ltd.
7// Contact: <eigen@codeplay.com>
8//
9// This Source Code Form is subject to the terms of the Mozilla
10// Public License v. 2.0. If a copy of the MPL was not distributed
11// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12
13/*****************************************************************
 * TensorSyclExtractFunctors.h
 *
 * \brief:
 *  Used to extract all the functors allocated to each node of the expression
 *  tree.
19 *
20*****************************************************************/
21
22#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
23#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
24
25namespace Eigen {
26namespace TensorSycl {
27namespace internal {
/// \struct FunctorExtractor
/// Collects the functors that were constructed on the host side so that they
/// can be packed and reused when the expression is reconstructed on the device.
/// Eigen functors are not stateless, so they cannot simply be re-instantiated
/// on the device; the already-instantiated functors have to be passed over.
///
/// This primary template is used for leaf nodes (TensorMap) and for nodes that
/// behave like leaf nodes (TensorForcedEval). It holds no functor and only
/// records the dimensions reported by the evaluator.
template <typename Evaluator>
struct FunctorExtractor {
  typedef typename Evaluator::Dimensions Dimensions;

  /// Dimensions captured from the evaluator at construction time.
  const Dimensions m_dimensions;

  FunctorExtractor(const Evaluator& evaluator)
      : m_dimensions(evaluator.dimensions()) {}

  /// Returns the dimensions captured at construction time.
  const Dimensions& dimensions() const { return m_dimensions; }
};
44
45/// specialisation of the \ref FunctorExtractor struct when the node type is
46/// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp
47template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
48struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > {
49  FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
50  OP func;
51  FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr)
52  : rhsExpr(expr.impl()), func(expr.functor()) {}
53};
54/// specialisation of the \ref FunctorExtractor struct when the node type is
55/// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp
56template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
57struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> >
58: FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
59
60/// specialisation of the \ref FunctorExtractor struct when the node type is
61/// const TensorCwiseBinaryOp
62template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
63struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {
64  FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
65  FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
66  OP func;
67  FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)
68  : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
69};
70
71/// specialisation of the \ref FunctorExtractor struct when the node type is
72/// const TensorCwiseBinaryOp
73template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
74struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >
75: FunctorExtractor<TensorEvaluator<const BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >{};
76
77/// specialisation of the \ref FunctorExtractor struct when the node type is
78/// const TensorCwiseTernaryOp
79template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev>
80struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > {
81  FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;
82  FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;
83  FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;
84  OP func;
85  FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
86  : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
87};
88
89/// specialisation of the \ref FunctorExtractor struct when the node type is
90/// TensorCwiseTernaryOp
91template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev>
92struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >
93:FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
94
95/// specialisation of the \ref FunctorExtractor struct when the node type is
96/// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated.
97template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
98struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {
99  FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr;
100  FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr;
101  FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr;
102  FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
103  : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
104};
105
106/// specialisation of the \ref FunctorExtractor struct when the node type is
107/// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated
108template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
109struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> >
110:FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
111
112/// specialisation of the \ref FunctorExtractor struct when the node type is
113/// const TensorAssignOp. This is an specialisation without OP so it has to be separated.
114template <typename LHSExpr, typename RHSExpr, typename Dev>
115struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > {
116  FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
117  FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
118  FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
119  : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
120};
121
122/// specialisation of the \ref FunctorExtractor struct when the node type is
123/// TensorAssignOp. This is an specialisation without OP so it has to be separated.
124template <typename LHSExpr, typename RHSExpr, typename Dev>
125struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> >
126:FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
127
128
129/// specialisation of the \ref FunctorExtractor struct when the node type is
130/// const TensorEvalToOp, This is an specialisation without OP so it has to be separated.
131template <typename RHSExpr, typename Dev>
132struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {
133  FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
134  FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
135  : rhsExpr(expr.impl()) {}
136};
137
138/// specialisation of the \ref FunctorExtractor struct when the node type is
139/// TensorEvalToOp. This is a specialisation without OP so it has to be separated.
140template <typename RHSExpr, typename Dev>
141struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> >
142: FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
143
144template<typename Dim, size_t NumOutputDim> struct DimConstr {
145template<typename InDim>
146  static inline Dim getDim(InDim dims ) {return dims;}
147};
148
149template<typename Dim> struct DimConstr<Dim, 0> {
150  template<typename InDim>
151    static inline Dim getDim(InDim dims ) {return Dim(dims.TotalSize());}
152};
153
154template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
155struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{
156  typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator;
157  typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions;
158  const Dimensions m_dimensions;
159  const Dimensions& dimensions() const { return m_dimensions; }
160  FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr)
161  : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {}
162};
163
164
165template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
166struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>
167: FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{};
168/// template deduction function for FunctorExtractor
169template <typename Evaluator>
170auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
171  return FunctorExtractor<Evaluator>(evaluator);
172}
173}  // namespace internal
174}  // namespace TensorSycl
175}  // namespace Eigen
176
177#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
178