1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
17#define TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
18
19#include <iostream>
20#include <iterator>
21
22namespace tensorflow {
23
/// Output-iterator adaptor over a `StoreType*` buffer. Every value assigned
/// through the iterator is passed through `ConversionOp` and the result is
/// written to the current buffer element.
///
/// Template parameters:
///   StoreType    - element type of the wrapped buffer (what gets stored).
///   InputType    - type of the values assigned through the iterator.
///   ConversionOp - callable invoked as `conversion_op(InputType)`; its result
///                  is assigned to a `StoreType` element.
///   OffsetT      - type used to express the distance between two iterators.
///
/// Dereferencing yields a proxy (`Reference`) instead of a real reference, so
/// the iterator is write-only: `value_type` and `pointer` are `void`.
template <typename StoreType, typename InputType, typename ConversionOp,
          typename OffsetT = ptrdiff_t>
class TransformOutputIterator {
 protected:
  // Proxy returned by operator* and operator[]. Assigning an InputType to it
  // stores conversion_op(value) into the element it points at.
  struct Reference {
    StoreType* ptr;              // Element that an assignment writes to.
    ConversionOp conversion_op;  // Applied to each assigned value.

    /// Constructor
    __host__ __device__ __forceinline__ Reference(StoreType* ptr,
                                                  ConversionOp conversion_op)
        : ptr(ptr), conversion_op(conversion_op) {}

    /// Assignment: writes conversion_op(val) to *ptr and returns the
    /// unconverted input value.
    __host__ __device__ __forceinline__ InputType operator=(InputType val) {
      *ptr = conversion_op(val);
      return val;
    }
  };

 public:
  // Required iterator traits
  typedef TransformOutputIterator self_type;  ///< My own type
  typedef OffsetT difference_type;  ///< Type of (iterator - iterator)
  typedef void value_type;          ///< Write-only: no readable element type
  typedef void pointer;             ///< Write-only: no element pointer type
  typedef Reference reference;      ///< Proxy returned by operator* / []
  typedef std::random_access_iterator_tag
      iterator_category;  ///< The iterator category

  /*private:*/
  // Deliberately left public (note the commented-out specifier above).
  StoreType* ptr;              // Current position in the output buffer.
  ConversionOp conversion_op;  // Copied into every proxy produced.

 public:
  /// Constructs an iterator positioned at `ptr`.
  /// @param ptr           Native pointer to wrap; must be implicitly
  ///                      convertible to `StoreType*`.
  /// @param conversionOp  Functor applied to every value written.
  template <typename QualifiedStoreType>
  __host__ __device__ __forceinline__ TransformOutputIterator(
      QualifiedStoreType* ptr, ConversionOp conversionOp)
      : ptr(ptr), conversion_op(conversionOp) {}

  /// Postfix increment: advances by one element, returns the prior position.
  __host__ __device__ __forceinline__ self_type operator++(int) {
    self_type retval = *this;
    ptr++;
    return retval;
  }

  /// Prefix increment: advances by one element. Returns *this by reference
  /// so that `++it` is an lvalue, as the iterator requirements specify.
  __host__ __device__ __forceinline__ self_type& operator++() {
    ++ptr;
    return *this;
  }

  /// Indirection: returns a write proxy for the current element.
  __host__ __device__ __forceinline__ reference operator*() const {
    return Reference(ptr, conversion_op);
  }

  /// Addition: returns a new iterator `n` elements ahead.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type operator+(Distance n) const {
    self_type retval(ptr + n, conversion_op);
    return retval;
  }

  /// Addition assignment: advances this iterator by `n` elements.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type& operator+=(Distance n) {
    ptr += n;
    return *this;
  }

  /// Subtraction: returns a new iterator `n` elements behind.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type operator-(Distance n) const {
    self_type retval(ptr - n, conversion_op);
    return retval;
  }

  /// Subtraction assignment: moves this iterator back by `n` elements.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type& operator-=(Distance n) {
    ptr -= n;
    return *this;
  }

  /// Distance: number of elements between `other` and this iterator.
  __host__ __device__ __forceinline__ difference_type
  operator-(self_type other) const {
    return ptr - other.ptr;
  }

  /// Array subscript: returns a write proxy for the element `n` positions
  /// from the current one.
  template <typename Distance>
  __host__ __device__ __forceinline__ reference operator[](Distance n) const {
    return Reference(ptr + n, conversion_op);
  }

  /// Equal to. Only positions are compared; conversion ops are ignored.
  __host__ __device__ __forceinline__ bool operator==(
      const self_type& rhs) const {
    return (ptr == rhs.ptr);
  }

  /// Not equal to. Only positions are compared; conversion ops are ignored.
  __host__ __device__ __forceinline__ bool operator!=(
      const self_type& rhs) const {
    return (ptr != rhs.ptr);
  }

  /// ostream operator. Prints nothing and returns `os` unchanged.
  friend std::ostream& operator<<(std::ostream& os,
                                  const self_type& /*itr*/) {
    return os;
  }
};
146
147}  // end namespace tensorflow
148
149#endif  // TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
150