1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
17#define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
18
19#include <type_traits>
20#include "tensorflow/core/lib/gtl/flatset.h"
21
22namespace tensorflow {
23namespace gtl {
24
25// CompactPointerSet<T> is like a std::unordered_set<T> but is optimized
26// for small sets (<= 1 element).  T must be a pointer type.
27template <typename T>
28class CompactPointerSet {
29 private:
30  using BigRep = FlatSet<T>;
31
32 public:
33  using value_type = T;
34
35  CompactPointerSet() : rep_(0) {}
36
37  ~CompactPointerSet() {
38    static_assert(
39        std::is_pointer<T>::value,
40        "CompactPointerSet<T> can only be used with T's that are pointers");
41    if (isbig()) delete big();
42  }
43
44  CompactPointerSet(const CompactPointerSet& other) : rep_(0) { *this = other; }
45
46  CompactPointerSet& operator=(const CompactPointerSet& other) {
47    if (this == &other) return *this;
48    if (other.isbig()) {
49      // big => any
50      if (!isbig()) MakeBig();
51      *big() = *other.big();
52    } else if (isbig()) {
53      // !big => big
54      big()->clear();
55      if (other.rep_ != 0) {
56        big()->insert(reinterpret_cast<T>(other.rep_));
57      }
58    } else {
59      // !big => !big
60      rep_ = other.rep_;
61    }
62    return *this;
63  }
64
65  class iterator {
66   public:
67    typedef ssize_t difference_type;
68    typedef T value_type;
69    typedef const T* pointer;
70    typedef const T& reference;
71    typedef ::std::forward_iterator_tag iterator_category;
72
73    explicit iterator(uintptr_t rep)
74        : bigrep_(false), single_(reinterpret_cast<T>(rep)) {}
75    explicit iterator(typename BigRep::iterator iter)
76        : bigrep_(true), single_(nullptr), iter_(iter) {}
77
78    iterator& operator++() {
79      if (bigrep_) {
80        ++iter_;
81      } else {
82        DCHECK(single_ != nullptr);
83        single_ = nullptr;
84      }
85      return *this;
86    }
87    // maybe post-increment?
88
89    bool operator==(const iterator& other) const {
90      if (bigrep_) {
91        return iter_ == other.iter_;
92      } else {
93        return single_ == other.single_;
94      }
95    }
96    bool operator!=(const iterator& other) const { return !(*this == other); }
97
98    const T& operator*() const {
99      if (bigrep_) {
100        return *iter_;
101      } else {
102        DCHECK(single_ != nullptr);
103        return single_;
104      }
105    }
106
107   private:
108    friend class CompactPointerSet;
109    bool bigrep_;
110    T single_;
111    typename BigRep::iterator iter_;
112  };
113  using const_iterator = iterator;
114
115  bool empty() const { return isbig() ? big()->empty() : (rep_ == 0); }
116  size_t size() const { return isbig() ? big()->size() : (rep_ == 0 ? 0 : 1); }
117
118  void clear() {
119    if (isbig()) {
120      delete big();
121    }
122    rep_ = 0;
123  }
124
125  std::pair<iterator, bool> insert(T elem) {
126    if (!isbig()) {
127      if (rep_ == 0) {
128        uintptr_t v = reinterpret_cast<uintptr_t>(elem);
129        if (v == 0 || ((v & 0x3) != 0)) {
130          // Cannot use small representation for nullptr.  Fall through.
131        } else {
132          rep_ = v;
133          return {iterator(v), true};
134        }
135      }
136      MakeBig();
137    }
138    auto p = big()->insert(elem);
139    return {iterator(p.first), p.second};
140  }
141
142  template <typename InputIter>
143  void insert(InputIter begin, InputIter end) {
144    for (; begin != end; ++begin) {
145      insert(*begin);
146    }
147  }
148
149  const_iterator begin() const {
150    return isbig() ? iterator(big()->begin()) : iterator(rep_);
151  }
152  const_iterator end() const {
153    return isbig() ? iterator(big()->end()) : iterator(0);
154  }
155
156  iterator find(T elem) const {
157    if (rep_ == reinterpret_cast<uintptr_t>(elem)) {
158      return iterator(rep_);
159    } else if (!isbig()) {
160      return iterator(0);
161    } else {
162      return iterator(big()->find(elem));
163    }
164  }
165
166  size_t count(T elem) const { return find(elem) != end() ? 1 : 0; }
167
168  size_t erase(T elem) {
169    if (!isbig()) {
170      if (rep_ == reinterpret_cast<uintptr_t>(elem)) {
171        rep_ = 0;
172        return 1;
173      } else {
174        return 0;
175      }
176    } else {
177      return big()->erase(elem);
178    }
179  }
180
181 private:
182  // Size         rep_
183  // -------------------------------------------------------------------------
184  // 0            0
185  // 1            The pointer itself (bottom bits == 00)
186  // large        Pointer to a BigRep (bottom bits == 01)
187  uintptr_t rep_;
188
189  bool isbig() const { return (rep_ & 0x3) == 1; }
190  BigRep* big() const {
191    DCHECK(isbig());
192    return reinterpret_cast<BigRep*>(rep_ - 1);
193  }
194
195  void MakeBig() {
196    DCHECK(!isbig());
197    BigRep* big = new BigRep;
198    if (rep_ != 0) {
199      big->insert(reinterpret_cast<T>(rep_));
200    }
201    rep_ = reinterpret_cast<uintptr_t>(big) + 0x1;
202  }
203};
204
205}  // namespace gtl
206}  // namespace tensorflow
207
208#endif  // TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
209