1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
13
14namespace Eigen {
15namespace internal {
16
17/** \class TensorIndexTuple
18  * \ingroup CXX11_Tensor_Module
19  *
20  * \brief Tensor + Index Tuple class.
21  *
22  *
23  */
24template<typename XprType>
25struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
26{
27  typedef traits<XprType> XprTraits;
28  typedef typename XprTraits::StorageKind StorageKind;
29  typedef typename XprTraits::Index Index;
30  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
31  typedef typename XprType::Nested Nested;
32  typedef typename remove_reference<Nested>::type _Nested;
33  static const int NumDimensions = XprTraits::NumDimensions;
34  static const int Layout = XprTraits::Layout;
35};
36
37template<typename XprType>
38struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
39{
40  typedef const TensorIndexTupleOp<XprType>& type;
41};
42
43template<typename XprType>
44struct nested<TensorIndexTupleOp<XprType>, 1,
45              typename eval<TensorIndexTupleOp<XprType> >::type>
46{
47  typedef TensorIndexTupleOp<XprType> type;
48};
49
50}  // end namespace internal
51
52template<typename XprType>
53class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
54{
55  public:
56  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
57  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58  typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
59  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
60  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
61  typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;
62
63  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
64      : m_xpr(expr) {}
65
66  EIGEN_DEVICE_FUNC
67  const typename internal::remove_all<typename XprType::Nested>::type&
68  expression() const { return m_xpr; }
69
70  protected:
71    typename XprType::Nested m_xpr;
72};
73
74// Eval as rvalue
75template<typename ArgType, typename Device>
76struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
77{
78  typedef TensorIndexTupleOp<ArgType> XprType;
79  typedef typename XprType::Index Index;
80  typedef typename XprType::Scalar Scalar;
81  typedef typename XprType::CoeffReturnType CoeffReturnType;
82
83  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
84  static const int NumDims = internal::array_size<Dimensions>::value;
85
86  enum {
87    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
88    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
89    BlockAccess = false,
90    Layout = TensorEvaluator<ArgType, Device>::Layout,
91    CoordAccess = false,  // to be implemented
92    RawAccess = false
93  };
94
95  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
96      : m_impl(op.expression(), device) { }
97
98  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
99    return m_impl.dimensions();
100  }
101
102  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
103    m_impl.evalSubExprsIfNeeded(NULL);
104    return true;
105  }
106  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
107    m_impl.cleanup();
108  }
109
110  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
111  {
112    return CoeffReturnType(index, m_impl.coeff(index));
113  }
114
115  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
116  costPerCoeff(bool vectorized) const {
117    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
118  }
119
120  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
121
122 protected:
123  TensorEvaluator<ArgType, Device> m_impl;
124};
125
126namespace internal {
127
128/** \class TensorTupleIndex
129  * \ingroup CXX11_Tensor_Module
130  *
131  * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
132  *
133  */
134template<typename ReduceOp, typename Dims, typename XprType>
135struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
136{
137  typedef traits<XprType> XprTraits;
138  typedef typename XprTraits::StorageKind StorageKind;
139  typedef typename XprTraits::Index Index;
140  typedef Index Scalar;
141  typedef typename XprType::Nested Nested;
142  typedef typename remove_reference<Nested>::type _Nested;
143  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
144  static const int Layout = XprTraits::Layout;
145};
146
147template<typename ReduceOp, typename Dims, typename XprType>
148struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
149{
150  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
151};
152
153template<typename ReduceOp, typename Dims, typename XprType>
154struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
155              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
156{
157  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
158};
159
160}  // end namespace internal
161
162template<typename ReduceOp, typename Dims, typename XprType>
163class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
164{
165  public:
166  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
167  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
168  typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
169  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
170  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
171  typedef Index CoeffReturnType;
172
173  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
174                                                          const ReduceOp& reduce_op,
175                                                          const int return_dim,
176                                                          const Dims& reduce_dims)
177      : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
178
179  EIGEN_DEVICE_FUNC
180  const typename internal::remove_all<typename XprType::Nested>::type&
181  expression() const { return m_xpr; }
182
183  EIGEN_DEVICE_FUNC
184  const ReduceOp& reduce_op() const { return m_reduce_op; }
185
186  EIGEN_DEVICE_FUNC
187  const Dims& reduce_dims() const { return m_reduce_dims; }
188
189  EIGEN_DEVICE_FUNC
190  int return_dim() const { return m_return_dim; }
191
192  protected:
193    typename XprType::Nested m_xpr;
194    const ReduceOp m_reduce_op;
195    const int m_return_dim;
196    const Dims m_reduce_dims;
197};
198
199// Eval as rvalue
200template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
201struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
202{
203  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
204  typedef typename XprType::Index Index;
205  typedef typename XprType::Scalar Scalar;
206  typedef typename XprType::CoeffReturnType CoeffReturnType;
207  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
208  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
209  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
210  static const int NumDims = internal::array_size<InputDimensions>::value;
211  typedef array<Index, NumDims> StrideDims;
212
213  enum {
214    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
215    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
216    BlockAccess = false,
217    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
218    CoordAccess = false,  // to be implemented
219    RawAccess = false
220  };
221
222  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
223      : m_orig_impl(op.expression(), device),
224        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
225        m_return_dim(op.return_dim()) {
226
227    gen_strides(m_orig_impl.dimensions(), m_strides);
228    if (Layout == static_cast<int>(ColMajor)) {
229      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
230      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
231    } else {
232      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
233      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
234    }
235    m_stride_div = m_strides[m_return_dim];
236  }
237
238  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
239    return m_impl.dimensions();
240  }
241
242  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
243    m_impl.evalSubExprsIfNeeded(NULL);
244    return true;
245  }
246  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
247    m_impl.cleanup();
248  }
249
250  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
251    const TupleType v = m_impl.coeff(index);
252    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
253  }
254
255  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
256
257  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
258  costPerCoeff(bool vectorized) const {
259    const double compute_cost = 1.0 +
260        (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
261    return m_orig_impl.costPerCoeff(vectorized) +
262           m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
263  }
264
265 private:
266  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
267    if (m_return_dim < 0) {
268      return;  // Won't be using the strides.
269    }
270    eigen_assert(m_return_dim < NumDims &&
271                 "Asking to convert index to a dimension outside of the rank");
272
273    // Calculate m_stride_div and m_stride_mod, which are used to
274    // calculate the value of an index w.r.t. the m_return_dim.
275    if (Layout == static_cast<int>(ColMajor)) {
276      strides[0] = 1;
277      for (int i = 1; i < NumDims; ++i) {
278        strides[i] = strides[i-1] * dims[i-1];
279      }
280    } else {
281      strides[NumDims-1] = 1;
282      for (int i = NumDims - 2; i >= 0; --i) {
283        strides[i] = strides[i+1] * dims[i+1];
284      }
285    }
286  }
287
288 protected:
289  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
290  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
291  const int m_return_dim;
292  StrideDims m_strides;
293  Index m_stride_mod;
294  Index m_stride_div;
295};
296
297} // end namespace Eigen
298
299#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
300