hash_set.h revision 88b6b051cc1c92b40537941c68061fc0d3b46a9f
/*
 * Copyright (C) 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ART_RUNTIME_BASE_HASH_SET_H_
#define ART_RUNTIME_BASE_HASH_SET_H_

#include <functional>
#include <memory>
#include <stdint.h>
#include <utility>

#include "bit_utils.h"
#include "logging.h"

namespace art {

// Default empty-slot functor: makes an item empty and tests whether an item is empty.
template <class T>
class DefaultEmptyFn {
 public:
  void MakeEmpty(T& item) const {
    item = T();
  }
  bool IsEmpty(const T& item) const {
    return item == T();
  }
};

template <class T>
class DefaultEmptyFn<T*> {
 public:
  void MakeEmpty(T*& item) const {
    item = nullptr;
  }
  bool IsEmpty(T* const& item) const {
    return item == nullptr;
  }
};


// Low memory version of a hash set, uses less memory than std::unordered_set since elements aren't
// boxed. Uses linear probing to resolve collisions.
// EmptyFn needs to implement two functions: MakeEmpty(T& item) and IsEmpty(const T& item).
// TODO: We could get rid of this requirement by using a bitmap, though maybe this would be slower
// and more complicated.
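//
// For illustration, a sketch of a custom EmptyFn for a set of ints that reserves -1 as the
// "empty" sentinel (hypothetical; not part of this file):
//
//   struct NegativeOneEmptyFn {
//     void MakeEmpty(int& item) const {
//       item = -1;  // -1 marks a free slot, so -1 itself can never be stored in the set.
//     }
//     bool IsEmpty(const int& item) const {
//       return item == -1;
//     }
//   };
//   // Usage: HashSet<int, NegativeOneEmptyFn> my_set;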
template <class T, class EmptyFn = DefaultEmptyFn<T>, class HashFn = std::hash<T>,
    class Pred = std::equal_to<T>, class Alloc = std::allocator<T>>
class HashSet {
  template <class Elem, class HashSetType>
  class BaseIterator {
   public:
    BaseIterator(const BaseIterator&) = default;
    BaseIterator(BaseIterator&&) = default;
    BaseIterator(HashSetType* hash_set, size_t index) : index_(index), hash_set_(hash_set) {
    }
    BaseIterator& operator=(const BaseIterator&) = default;
    BaseIterator& operator=(BaseIterator&&) = default;

    bool operator==(const BaseIterator& other) const {
      return hash_set_ == other.hash_set_ && this->index_ == other.index_;
    }

    bool operator!=(const BaseIterator& other) const {
      return !(*this == other);
    }

    BaseIterator operator++() {  // Value after modification.
      this->index_ = this->NextNonEmptySlot(this->index_, hash_set_);
      return *this;
    }

    BaseIterator operator++(int) {
      BaseIterator temp = *this;
      this->index_ = this->NextNonEmptySlot(this->index_, hash_set_);
      return temp;
    }

    Elem& operator*() const {
      DCHECK(!hash_set_->IsFreeSlot(this->index_));
      return hash_set_->ElementForIndex(this->index_);
    }

    Elem* operator->() const {
      return &**this;
    }

    // TODO: Operator -- --(int)

   private:
    size_t index_;
    HashSetType* hash_set_;

    size_t NextNonEmptySlot(size_t index, const HashSet* hash_set) const {
      const size_t num_buckets = hash_set->NumBuckets();
      DCHECK_LT(index, num_buckets);
      do {
        ++index;
      } while (index < num_buckets && hash_set->IsFreeSlot(index));
      return index;
    }

    friend class HashSet;
  };

 public:
  static constexpr double kDefaultMinLoadFactor = 0.5;
  static constexpr double kDefaultMaxLoadFactor = 0.9;
  static constexpr size_t kMinBuckets = 1000;

  typedef BaseIterator<T, HashSet> Iterator;
  typedef BaseIterator<const T, const HashSet> ConstIterator;

  // If we don't own the data, this will create a new array which owns the data.
  void Clear() {
    DeallocateStorage();
    AllocateStorage(1);
    num_elements_ = 0;
    elements_until_expand_ = 0;
  }

  HashSet() : num_elements_(0), num_buckets_(0), owns_data_(false), data_(nullptr),
      min_load_factor_(kDefaultMinLoadFactor), max_load_factor_(kDefaultMaxLoadFactor) {
    Clear();
  }

  HashSet(const HashSet& other) : num_elements_(0), num_buckets_(0), owns_data_(false),
      data_(nullptr) {
    *this = other;
  }

  HashSet(HashSet&& other) : num_elements_(0), num_buckets_(0), owns_data_(false),
      data_(nullptr) {
    *this = std::move(other);
  }

  // Construct from existing data.
  // Read from a block of memory; if make_copy_of_data is false, then data_ points into the
  // passed-in ptr.
  HashSet(const uint8_t* ptr, bool make_copy_of_data, size_t* read_count) {
    uint64_t temp;
    size_t offset = 0;
    offset = ReadFromBytes(ptr, offset, &temp);
    num_elements_ = static_cast<size_t>(temp);
    offset = ReadFromBytes(ptr, offset, &temp);
    num_buckets_ = static_cast<size_t>(temp);
    CHECK_LE(num_elements_, num_buckets_);
    offset = ReadFromBytes(ptr, offset, &temp);
    elements_until_expand_ = static_cast<size_t>(temp);
    offset = ReadFromBytes(ptr, offset, &min_load_factor_);
    offset = ReadFromBytes(ptr, offset, &max_load_factor_);
    if (!make_copy_of_data) {
      owns_data_ = false;
      data_ = const_cast<T*>(reinterpret_cast<const T*>(ptr + offset));
      offset += sizeof(*data_) * num_buckets_;
    } else {
      AllocateStorage(num_buckets_);
      // Read elements; note that this may not be safe for cross compilation if the elements are
      // pointer sized.
      for (size_t i = 0; i < num_buckets_; ++i) {
        offset = ReadFromBytes(ptr, offset, &data_[i]);
      }
    }
    // Caller responsible for aligning.
    *read_count = offset;
  }

178
179  // Returns how large the table is after being written. If target is null, then no writing happens
180  // but the size is still returned. Target must be 8 byte aligned.
181  size_t WriteToMemory(uint8_t* ptr) {
182    size_t offset = 0;
183    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(num_elements_));
184    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(num_buckets_));
185    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(elements_until_expand_));
186    offset = WriteToBytes(ptr, offset, min_load_factor_);
187    offset = WriteToBytes(ptr, offset, max_load_factor_);
188    // Write elements, not that this may not be safe for cross compilation if the elements are
189    // pointer sized.
190    for (size_t i = 0; i < num_buckets_; ++i) {
191      offset = WriteToBytes(ptr, offset, data_[i]);
192    }
193    // Caller responsible for aligning.
194    return offset;
195  }
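
  // For illustration, a sketch of the serialize/deserialize round trip (hypothetical caller
  // code; assumes the buffer is 8-byte aligned, as operator new-backed storage typically is):
  //
  //   HashSet<uint64_t> set;
  //   set.Insert(42u);
  //   std::vector<uint8_t> buffer(set.WriteToMemory(nullptr));  // First query the size.
  //   set.WriteToMemory(buffer.data());                         // Then actually write.
  //   size_t read_count;
  //   HashSet<uint64_t> copy(buffer.data(), /*make_copy_of_data*/ true, &read_count);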

  ~HashSet() {
    DeallocateStorage();
  }

  HashSet& operator=(HashSet&& other) {
    std::swap(data_, other.data_);
    std::swap(num_buckets_, other.num_buckets_);
    std::swap(num_elements_, other.num_elements_);
    std::swap(elements_until_expand_, other.elements_until_expand_);
    std::swap(min_load_factor_, other.min_load_factor_);
    std::swap(max_load_factor_, other.max_load_factor_);
    std::swap(owns_data_, other.owns_data_);
    return *this;
  }

  HashSet& operator=(const HashSet& other) {
    DeallocateStorage();
    AllocateStorage(other.NumBuckets());
    for (size_t i = 0; i < num_buckets_; ++i) {
      ElementForIndex(i) = other.data_[i];
    }
    num_elements_ = other.num_elements_;
    elements_until_expand_ = other.elements_until_expand_;
    min_load_factor_ = other.min_load_factor_;
    max_load_factor_ = other.max_load_factor_;
    return *this;
  }

  // Lower case for C++11 range-based for loops.
  Iterator begin() {
    Iterator ret(this, 0);
    if (num_buckets_ != 0 && IsFreeSlot(ret.index_)) {
      ++ret;  // Skip all the empty slots.
    }
    return ret;
  }

  // Lower case for C++11 range-based for loops. const version.
  ConstIterator begin() const {
    ConstIterator ret(this, 0);
    if (num_buckets_ != 0 && IsFreeSlot(ret.index_)) {
      ++ret;  // Skip all the empty slots.
    }
    return ret;
  }

  // Lower case for C++11 range-based for loops.
  Iterator end() {
    return Iterator(this, NumBuckets());
  }

  // Lower case for C++11 range-based for loops. const version.
  ConstIterator end() const {
    return ConstIterator(this, NumBuckets());
  }
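
  // For illustration, iterating with a C++11 range-based for loop (a sketch; note that the
  // default EmptyFn reserves the default-constructed value, here the empty string):
  //
  //   HashSet<std::string> set;
  //   set.Insert("hello");
  //   set.Insert("world");
  //   for (const std::string& s : set) {
  //     LOG(INFO) << s;
  //   }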

  bool Empty() const {
    return Size() == 0;
  }

  // Erase algorithm:
  // Make an empty slot where the iterator is pointing.
  // Scan forwards until we hit another empty slot.
  // If an element in between doesn't rehash to the range from the current empty slot to the
  // iterator, it must be from before the empty slot; in that case we can move it to the empty
  // slot and make the slot we just moved from the new empty slot.
  // Relies on maintaining the invariant that there are no empty slots from the 'ideal' index of
  // an element to its actual location/index.
  Iterator Erase(Iterator it) {
    // empty_index is the index that will become empty.
    size_t empty_index = it.index_;
    DCHECK(!IsFreeSlot(empty_index));
    size_t next_index = empty_index;
    bool filled = false;  // True if we filled the empty index.
    while (true) {
      next_index = NextIndex(next_index);
      T& next_element = ElementForIndex(next_index);
      // If the next element is empty, we are done. Make sure to clear the current empty index.
      if (emptyfn_.IsEmpty(next_element)) {
        emptyfn_.MakeEmpty(ElementForIndex(empty_index));
        break;
      }
      // Otherwise try to see if the next element can fill the current empty index.
      const size_t next_hash = hashfn_(next_element);
      // Calculate the ideal index; if it is within empty_index + 1 to next_index then there is
      // nothing we can do.
      size_t next_ideal_index = IndexForHash(next_hash);
      // Loop around if needed for our check.
      size_t unwrapped_next_index = next_index;
      if (unwrapped_next_index < empty_index) {
        unwrapped_next_index += NumBuckets();
      }
      // Loop around if needed for our check.
      size_t unwrapped_next_ideal_index = next_ideal_index;
      if (unwrapped_next_ideal_index < empty_index) {
        unwrapped_next_ideal_index += NumBuckets();
      }
      if (unwrapped_next_ideal_index <= empty_index ||
          unwrapped_next_ideal_index > unwrapped_next_index) {
        // If the target index isn't within our current range it must have been probed from before
        // the empty index.
        ElementForIndex(empty_index) = std::move(next_element);
        filled = true;  // TODO: Optimize
        empty_index = next_index;
      }
    }
    --num_elements_;
    // If we didn't fill the slot then we need to go to the next non-free slot.
    if (!filled) {
      ++it;
    }
    return it;
  }
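
  // For illustration, the usual erase-while-iterating pattern (a sketch; ShouldRemove is a
  // hypothetical predicate):
  //
  //   for (auto it = set.begin(); it != set.end(); ) {
  //     if (ShouldRemove(*it)) {
  //       it = set.Erase(it);  // Erase returns the iterator to continue from.
  //     } else {
  //       ++it;
  //     }
  //   }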

  // Find an element, returns end() if not found.
  // Allows custom key (K) types; an example of when this is useful: a set of Class* hashed by
  // name, where we want to find a class by name without having to allocate a dummy object on the
  // heap just for the lookup.
  template <typename K>
  Iterator Find(const K& key) {
    return FindWithHash(key, hashfn_(key));
  }

  template <typename K>
  ConstIterator Find(const K& key) const {
    return FindWithHash(key, hashfn_(key));
  }

  template <typename K>
  Iterator FindWithHash(const K& key, size_t hash) {
    return Iterator(this, FindIndex(key, hash));
  }

  template <typename K>
  ConstIterator FindWithHash(const K& key, size_t hash) const {
    return ConstIterator(this, FindIndex(key, hash));
  }
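
  // For illustration, a lookup keyed by a different type than T (a sketch; assumes HashFn and
  // Pred are written to also accept the key type, and hashfn is the set's hash functor):
  //
  //   const char* name = "MyClass";
  //   const size_t hash = hashfn(name);
  //   auto it = set.FindWithHash(name, hash);  // Reuses the hash computed above.
  //   if (it != set.end()) {
  //     // Found, without constructing a dummy element of type T.
  //   }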

  // Insert an element, allows duplicates.
  void Insert(const T& element) {
    InsertWithHash(element, hashfn_(element));
  }

  void InsertWithHash(const T& element, size_t hash) {
    DCHECK_EQ(hash, hashfn_(element));
    if (num_elements_ >= elements_until_expand_) {
      Expand();
      DCHECK_LT(num_elements_, elements_until_expand_);
    }
    const size_t index = FirstAvailableSlot(IndexForHash(hash));
    data_[index] = element;
    ++num_elements_;
  }
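
  // For illustration, since Insert allows duplicates, a caller wanting set semantics can combine
  // FindWithHash and InsertWithHash so the hash is computed only once (a sketch; hashfn is the
  // set's hash functor):
  //
  //   const size_t hash = hashfn(element);
  //   if (set.FindWithHash(element, hash) == set.end()) {
  //     set.InsertWithHash(element, hash);
  //   }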

  size_t Size() const {
    return num_elements_;
  }

  void ShrinkToMaximumLoad() {
    Resize(Size() / max_load_factor_);
  }

  // Total distance that inserted elements were probed. Used for measuring how good hash
  // functions are.
  size_t TotalProbeDistance() const {
    size_t total = 0;
    for (size_t i = 0; i < NumBuckets(); ++i) {
      const T& element = ElementForIndex(i);
      if (!emptyfn_.IsEmpty(element)) {
        size_t ideal_location = IndexForHash(hashfn_(element));
        if (ideal_location > i) {
          total += i + NumBuckets() - ideal_location;
        } else {
          total += i - ideal_location;
        }
      }
    }
    return total;
  }

  // Calculate the current load factor and return it.
  double CalculateLoadFactor() const {
    return static_cast<double>(Size()) / static_cast<double>(NumBuckets());
  }
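
  // For illustration, measuring hash quality with the two functions above (a sketch): a good
  // hash function keeps the average probe distance small even at a high load factor.
  //
  //   const double load = set.CalculateLoadFactor();
  //   const double average_probe_distance = set.Size() == 0
  //       ? 0.0
  //       : static_cast<double>(set.TotalProbeDistance()) / set.Size();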

  // Make sure that everything reinserts in the right spot. Returns the number of errors.
  size_t Verify() {
    size_t errors = 0;
    for (size_t i = 0; i < num_buckets_; ++i) {
      T& element = data_[i];
      if (!emptyfn_.IsEmpty(element)) {
        T temp;
        emptyfn_.MakeEmpty(temp);
        std::swap(temp, element);
        size_t first_slot = FirstAvailableSlot(IndexForHash(hashfn_(temp)));
        if (i != first_slot) {
          LOG(ERROR) << "Element " << i << " should be in slot " << first_slot;
          ++errors;
        }
        std::swap(temp, element);
      }
    }
    return errors;
  }

 private:
  T& ElementForIndex(size_t index) {
    DCHECK_LT(index, NumBuckets());
    DCHECK(data_ != nullptr);
    return data_[index];
  }

  const T& ElementForIndex(size_t index) const {
    DCHECK_LT(index, NumBuckets());
    DCHECK(data_ != nullptr);
    return data_[index];
  }

  size_t IndexForHash(size_t hash) const {
    // Protect against undefined behavior (division by zero).
    if (UNLIKELY(num_buckets_ == 0)) {
      return 0;
    }
    return hash % num_buckets_;
  }

  size_t NextIndex(size_t index) const {
    if (UNLIKELY(++index >= num_buckets_)) {
      DCHECK_EQ(index, NumBuckets());
      return 0;
    }
    return index;
  }

  // Find the hash table slot for an element, or return NumBuckets() if not found.
  // This value for not found is important so that Iterator(this, FindIndex(...)) == end().
  template <typename K>
  size_t FindIndex(const K& element, size_t hash) const {
    // With zero buckets, return 0; since 0 == NumBuckets() here, this reads as "not found".
    if (UNLIKELY(NumBuckets() == 0)) {
      return 0;
    }
    DCHECK_EQ(hashfn_(element), hash);
    size_t index = IndexForHash(hash);
    while (true) {
      const T& slot = ElementForIndex(index);
      if (emptyfn_.IsEmpty(slot)) {
        return NumBuckets();
      }
      if (pred_(slot, element)) {
        return index;
      }
      index = NextIndex(index);
    }
  }

  bool IsFreeSlot(size_t index) const {
    return emptyfn_.IsEmpty(ElementForIndex(index));
  }

  size_t NumBuckets() const {
    return num_buckets_;
  }

  // Allocate a number of buckets.
  void AllocateStorage(size_t num_buckets) {
    num_buckets_ = num_buckets;
    data_ = allocfn_.allocate(num_buckets_);
    owns_data_ = true;
    for (size_t i = 0; i < num_buckets_; ++i) {
      allocfn_.construct(allocfn_.address(data_[i]));
      emptyfn_.MakeEmpty(data_[i]);
    }
  }

  void DeallocateStorage() {
    if (owns_data_) {
      for (size_t i = 0; i < NumBuckets(); ++i) {
        allocfn_.destroy(allocfn_.address(data_[i]));
      }
      if (data_ != nullptr) {
        allocfn_.deallocate(data_, NumBuckets());
      }
      owns_data_ = false;
    }
    data_ = nullptr;
    num_buckets_ = 0;
  }

  // Expand the set based on the load factors.
  void Expand() {
    // Resize based on the minimum load factor: pick the bucket count that brings the load factor
    // down to min_load_factor_.
    const size_t min_num_buckets = static_cast<size_t>(Size() / min_load_factor_);
    Resize(min_num_buckets);
  }

  // Expand / shrink the table to the new specified size.
  void Resize(size_t new_size) {
    if (new_size < kMinBuckets) {
      new_size = kMinBuckets;
    }
    DCHECK_GE(new_size, Size());
    T* const old_data = data_;
    size_t old_num_buckets = num_buckets_;
    // Reinsert all of the old elements.
    const bool owned_data = owns_data_;
    AllocateStorage(new_size);
    for (size_t i = 0; i < old_num_buckets; ++i) {
      T& element = old_data[i];
      if (!emptyfn_.IsEmpty(element)) {
        data_[FirstAvailableSlot(IndexForHash(hashfn_(element)))] = std::move(element);
      }
      if (owned_data) {
        allocfn_.destroy(allocfn_.address(element));
      }
    }
    if (owned_data) {
      allocfn_.deallocate(old_data, old_num_buckets);
    }

    // When we hit elements_until_expand_, we are at the max load factor and must expand again.
    elements_until_expand_ = NumBuckets() * max_load_factor_;
  }
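
  // For illustration, the growth arithmetic with the default load factors (0.5 min, 0.9 max):
  // after Resize(1000), elements_until_expand_ is 1000 * 0.9 = 900. Once 900 elements are
  // present, the next insertion triggers Expand(), which resizes to 900 / 0.5 = 1800 buckets,
  // so the new expansion threshold becomes 1800 * 0.9 = 1620.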

  ALWAYS_INLINE size_t FirstAvailableSlot(size_t index) const {
    DCHECK_LT(index, NumBuckets());  // Don't try to get a slot out of range.
    size_t non_empty_count = 0;
    while (!emptyfn_.IsEmpty(data_[index])) {
      index = NextIndex(index);
      non_empty_count++;
      DCHECK_LE(non_empty_count, NumBuckets());  // Don't loop forever.
    }
    return index;
  }

  // Return new offset.
  template <typename Elem>
  static size_t WriteToBytes(uint8_t* ptr, size_t offset, Elem n) {
    DCHECK_ALIGNED(ptr + offset, sizeof(n));
    if (ptr != nullptr) {
      *reinterpret_cast<Elem*>(ptr + offset) = n;
    }
    return offset + sizeof(n);
  }

  template <typename Elem>
  static size_t ReadFromBytes(const uint8_t* ptr, size_t offset, Elem* out) {
    DCHECK(ptr != nullptr);
    DCHECK_ALIGNED(ptr + offset, sizeof(*out));
    *out = *reinterpret_cast<const Elem*>(ptr + offset);
    return offset + sizeof(*out);
  }

  Alloc allocfn_;  // Allocator function.
  HashFn hashfn_;  // Hashing function.
  EmptyFn emptyfn_;  // IsEmpty/MakeEmpty function.
  Pred pred_;  // Equality function.
  size_t num_elements_;  // Number of inserted elements.
  size_t num_buckets_;  // Number of hash table buckets.
  size_t elements_until_expand_;  // Maximum number of elements until we expand the table.
  bool owns_data_;  // If we own data_ and are responsible for freeing it.
  T* data_;  // Backing storage.
  double min_load_factor_;
  double max_load_factor_;
};

}  // namespace art

#endif  // ART_RUNTIME_BASE_HASH_SET_H_