1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
12
13namespace Eigen {
14
15namespace internal {
16
17template <typename Dimensions, typename Scalar>
18class TensorLazyBaseEvaluator {
19 public:
20  TensorLazyBaseEvaluator() : m_refcount(0) { }
21  virtual ~TensorLazyBaseEvaluator() { }
22
23  EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
24  EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
25
26  EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
27  EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
28
29  void incrRefCount() { ++m_refcount; }
30  void decrRefCount() { --m_refcount; }
31  int refCount() const { return m_refcount; }
32
33 private:
34  // No copy, no assigment;
35  TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
36  TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
37
38  int m_refcount;
39};
40
41
42template <typename Dimensions, typename Expr, typename Device>
43class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
44 public:
45  //  typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
46  typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
47
48  TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
49    m_dims = m_impl.dimensions();
50    m_impl.evalSubExprsIfNeeded(NULL);
51  }
52  virtual ~TensorLazyEvaluatorReadOnly() {
53    m_impl.cleanup();
54  }
55
56  EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const {
57    return m_dims;
58  }
59  EIGEN_DEVICE_FUNC virtual const Scalar* data() const {
60    return m_impl.data();
61  }
62
63  EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const {
64    return m_impl.coeff(index);
65  }
66  EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
67    eigen_assert(false && "can't reference the coefficient of a rvalue");
68    return m_dummy;
69  };
70
71 protected:
72  TensorEvaluator<Expr, Device> m_impl;
73  Dimensions m_dims;
74  Scalar m_dummy;
75};
76
77template <typename Dimensions, typename Expr, typename Device>
78class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
79 public:
80  typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
81  typedef typename Base::Scalar Scalar;
82
83  TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
84  }
85  virtual ~TensorLazyEvaluatorWritable() {
86  }
87
88  EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) {
89    return this->m_impl.coeffRef(index);
90  }
91};
92
93template <typename Dimensions, typename Expr, typename Device>
94class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
95                            TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
96                            TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
97 public:
98  typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
99                                         TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
100                                         TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
101  typedef typename Base::Scalar Scalar;
102
103  TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
104  }
105  virtual ~TensorLazyEvaluator() {
106  }
107};
108
109}  // namespace internal
110
111
112/** \class TensorRef
113  * \ingroup CXX11_Tensor_Module
114  *
115  * \brief A reference to a tensor expression
116  * The expression will be evaluated lazily (as much as possible).
117  *
118  */
119template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
120{
121  public:
122    typedef TensorRef<PlainObjectType> Self;
123    typedef typename PlainObjectType::Base Base;
124    typedef typename Eigen::internal::nested<Self>::type Nested;
125    typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
126    typedef typename internal::traits<PlainObjectType>::Index Index;
127    typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
128    typedef typename NumTraits<Scalar>::Real RealScalar;
129    typedef typename Base::CoeffReturnType CoeffReturnType;
130    typedef Scalar* PointerType;
131    typedef PointerType PointerArgType;
132
133    static const Index NumIndices = PlainObjectType::NumIndices;
134    typedef typename PlainObjectType::Dimensions Dimensions;
135
136    enum {
137      IsAligned = false,
138      PacketAccess = false,
139      Layout = PlainObjectType::Layout,
140      CoordAccess = false,  // to be implemented
141      RawAccess = false
142    };
143
144    EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
145    }
146
147    template <typename Expression>
148    EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
149      m_evaluator->incrRefCount();
150    }
151
152    template <typename Expression>
153    EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
154      unrefEvaluator();
155      m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
156      m_evaluator->incrRefCount();
157      return *this;
158    }
159
160    ~TensorRef() {
161      unrefEvaluator();
162    }
163
164    TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
165      eigen_assert(m_evaluator->refCount() > 0);
166      m_evaluator->incrRefCount();
167    }
168
169    TensorRef& operator = (const TensorRef& other) {
170      if (this != &other) {
171        unrefEvaluator();
172        m_evaluator = other.m_evaluator;
173        eigen_assert(m_evaluator->refCount() > 0);
174        m_evaluator->incrRefCount();
175      }
176      return *this;
177    }
178
179    EIGEN_DEVICE_FUNC
180    EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
181    EIGEN_DEVICE_FUNC
182    EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
183    EIGEN_DEVICE_FUNC
184    EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
185    EIGEN_DEVICE_FUNC
186    EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
187    EIGEN_DEVICE_FUNC
188    EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
189
190    EIGEN_DEVICE_FUNC
191    EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
192    {
193      return m_evaluator->coeff(index);
194    }
195
196#if EIGEN_HAS_VARIADIC_TEMPLATES
197    template<typename... IndexTypes> EIGEN_DEVICE_FUNC
198    EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
199    {
200      const std::size_t num_indices = (sizeof...(otherIndices) + 1);
201      const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
202      return coeff(indices);
203    }
204    template<typename... IndexTypes> EIGEN_DEVICE_FUNC
205    EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
206    {
207      const std::size_t num_indices = (sizeof...(otherIndices) + 1);
208      const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
209      return coeffRef(indices);
210    }
211#else
212
213    EIGEN_DEVICE_FUNC
214    EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
215    {
216      array<Index, 2> indices;
217      indices[0] = i0;
218      indices[1] = i1;
219      return coeff(indices);
220    }
221    EIGEN_DEVICE_FUNC
222    EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
223    {
224      array<Index, 3> indices;
225      indices[0] = i0;
226      indices[1] = i1;
227      indices[2] = i2;
228      return coeff(indices);
229    }
230    EIGEN_DEVICE_FUNC
231    EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
232    {
233      array<Index, 4> indices;
234      indices[0] = i0;
235      indices[1] = i1;
236      indices[2] = i2;
237      indices[3] = i3;
238      return coeff(indices);
239    }
240    EIGEN_DEVICE_FUNC
241    EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
242    {
243      array<Index, 5> indices;
244      indices[0] = i0;
245      indices[1] = i1;
246      indices[2] = i2;
247      indices[3] = i3;
248      indices[4] = i4;
249      return coeff(indices);
250    }
251    EIGEN_DEVICE_FUNC
252    EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
253    {
254      array<Index, 2> indices;
255      indices[0] = i0;
256      indices[1] = i1;
257      return coeffRef(indices);
258    }
259    EIGEN_DEVICE_FUNC
260    EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
261    {
262      array<Index, 3> indices;
263      indices[0] = i0;
264      indices[1] = i1;
265      indices[2] = i2;
266      return coeffRef(indices);
267    }
268    EIGEN_DEVICE_FUNC
269    EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
270    {
271      array<Index, 4> indices;
272      indices[0] = i0;
273      indices[1] = i1;
274      indices[2] = i2;
275      indices[3] = i3;
276      return coeffRef(indices);
277    }
278    EIGEN_DEVICE_FUNC
279    EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
280    {
281      array<Index, 5> indices;
282      indices[0] = i0;
283      indices[1] = i1;
284      indices[2] = i2;
285      indices[3] = i3;
286      indices[4] = i4;
287      return coeffRef(indices);
288    }
289#endif
290
291    template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
292    EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
293    {
294      const Dimensions& dims = this->dimensions();
295      Index index = 0;
296      if (PlainObjectType::Options & RowMajor) {
297        index += indices[0];
298        for (size_t i = 1; i < NumIndices; ++i) {
299          index = index * dims[i] + indices[i];
300        }
301      } else {
302        index += indices[NumIndices-1];
303        for (int i = NumIndices-2; i >= 0; --i) {
304          index = index * dims[i] + indices[i];
305        }
306      }
307      return m_evaluator->coeff(index);
308    }
309    template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
310    EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
311    {
312      const Dimensions& dims = this->dimensions();
313      Index index = 0;
314      if (PlainObjectType::Options & RowMajor) {
315        index += indices[0];
316        for (size_t i = 1; i < NumIndices; ++i) {
317          index = index * dims[i] + indices[i];
318        }
319      } else {
320        index += indices[NumIndices-1];
321        for (int i = NumIndices-2; i >= 0; --i) {
322          index = index * dims[i] + indices[i];
323        }
324      }
325      return m_evaluator->coeffRef(index);
326    }
327
328    EIGEN_DEVICE_FUNC
329    EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
330    {
331      return m_evaluator->coeff(index);
332    }
333
334    EIGEN_DEVICE_FUNC
335    EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
336    {
337      return m_evaluator->coeffRef(index);
338    }
339
340  private:
341    EIGEN_STRONG_INLINE void unrefEvaluator() {
342      if (m_evaluator) {
343        m_evaluator->decrRefCount();
344        if (m_evaluator->refCount() == 0) {
345          delete m_evaluator;
346        }
347      }
348    }
349
350  internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
351};
352
353
354// evaluator for rvalues
355template<typename Derived, typename Device>
356struct TensorEvaluator<const TensorRef<Derived>, Device>
357{
358  typedef typename Derived::Index Index;
359  typedef typename Derived::Scalar Scalar;
360  typedef typename Derived::Scalar CoeffReturnType;
361  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
362  typedef typename Derived::Dimensions Dimensions;
363
364  enum {
365    IsAligned = false,
366    PacketAccess = false,
367    Layout = TensorRef<Derived>::Layout,
368    CoordAccess = false,  // to be implemented
369    RawAccess = false
370  };
371
372  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
373      : m_ref(m)
374  { }
375
376  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
377
378  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
379    return true;
380  }
381
382  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
383
384  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
385    return m_ref.coeff(index);
386  }
387
388  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
389    return m_ref.coeffRef(index);
390  }
391
392  EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); }
393
394 protected:
395  TensorRef<Derived> m_ref;
396};
397
398
399// evaluator for lvalues
400template<typename Derived, typename Device>
401struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
402{
403  typedef typename Derived::Index Index;
404  typedef typename Derived::Scalar Scalar;
405  typedef typename Derived::Scalar CoeffReturnType;
406  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
407  typedef typename Derived::Dimensions Dimensions;
408
409  typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
410
411  enum {
412    IsAligned = false,
413    PacketAccess = false,
414    RawAccess = false
415  };
416
417  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
418  { }
419
420  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
421    return this->m_ref.coeffRef(index);
422  }
423};
424
425
426
427} // end namespace Eigen
428
429#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
430