1// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file. See the AUTHORS file for names of contributors.
4
5#include "db/skiplist.h"
6#include <set>
7#include "leveldb/env.h"
8#include "util/arena.h"
9#include "util/hash.h"
10#include "util/random.h"
11#include "util/testharness.h"
12
13namespace leveldb {
14
typedef uint64_t Key;

// Three-way comparator over Key used to order the skip lists below.
// Returns -1 when a sorts before b, +1 when it sorts after, and 0 when
// the two keys are equal.
struct Comparator {
  int operator()(const Key& a, const Key& b) const {
    if (a == b) {
      return 0;
    }
    return (a < b) ? -1 : +1;
  }
};
28
// Stateless fixture type named as the first argument of every
// TEST(SkipTest, ...) in this file.
class SkipTest {
};
30
31TEST(SkipTest, Empty) {
32  Arena arena;
33  Comparator cmp;
34  SkipList<Key, Comparator> list(cmp, &arena);
35  ASSERT_TRUE(!list.Contains(10));
36
37  SkipList<Key, Comparator>::Iterator iter(&list);
38  ASSERT_TRUE(!iter.Valid());
39  iter.SeekToFirst();
40  ASSERT_TRUE(!iter.Valid());
41  iter.Seek(100);
42  ASSERT_TRUE(!iter.Valid());
43  iter.SeekToLast();
44  ASSERT_TRUE(!iter.Valid());
45}
46
47TEST(SkipTest, InsertAndLookup) {
48  const int N = 2000;
49  const int R = 5000;
50  Random rnd(1000);
51  std::set<Key> keys;
52  Arena arena;
53  Comparator cmp;
54  SkipList<Key, Comparator> list(cmp, &arena);
55  for (int i = 0; i < N; i++) {
56    Key key = rnd.Next() % R;
57    if (keys.insert(key).second) {
58      list.Insert(key);
59    }
60  }
61
62  for (int i = 0; i < R; i++) {
63    if (list.Contains(i)) {
64      ASSERT_EQ(keys.count(i), 1);
65    } else {
66      ASSERT_EQ(keys.count(i), 0);
67    }
68  }
69
70  // Simple iterator tests
71  {
72    SkipList<Key, Comparator>::Iterator iter(&list);
73    ASSERT_TRUE(!iter.Valid());
74
75    iter.Seek(0);
76    ASSERT_TRUE(iter.Valid());
77    ASSERT_EQ(*(keys.begin()), iter.key());
78
79    iter.SeekToFirst();
80    ASSERT_TRUE(iter.Valid());
81    ASSERT_EQ(*(keys.begin()), iter.key());
82
83    iter.SeekToLast();
84    ASSERT_TRUE(iter.Valid());
85    ASSERT_EQ(*(keys.rbegin()), iter.key());
86  }
87
88  // Forward iteration test
89  for (int i = 0; i < R; i++) {
90    SkipList<Key, Comparator>::Iterator iter(&list);
91    iter.Seek(i);
92
93    // Compare against model iterator
94    std::set<Key>::iterator model_iter = keys.lower_bound(i);
95    for (int j = 0; j < 3; j++) {
96      if (model_iter == keys.end()) {
97        ASSERT_TRUE(!iter.Valid());
98        break;
99      } else {
100        ASSERT_TRUE(iter.Valid());
101        ASSERT_EQ(*model_iter, iter.key());
102        ++model_iter;
103        iter.Next();
104      }
105    }
106  }
107
108  // Backward iteration test
109  {
110    SkipList<Key, Comparator>::Iterator iter(&list);
111    iter.SeekToLast();
112
113    // Compare against model iterator
114    for (std::set<Key>::reverse_iterator model_iter = keys.rbegin();
115         model_iter != keys.rend();
116         ++model_iter) {
117      ASSERT_TRUE(iter.Valid());
118      ASSERT_EQ(*model_iter, iter.key());
119      iter.Prev();
120    }
121    ASSERT_TRUE(!iter.Valid());
122  }
123}
124
125// We want to make sure that with a single writer and multiple
126// concurrent readers (with no synchronization other than when a
127// reader's iterator is created), the reader always observes all the
// data that was present in the skip list when the iterator was
// constructed.  Because insertions are happening concurrently, we may
130// also observe new values that were inserted since the iterator was
131// constructed, but we should never miss any values that were present
132// at iterator construction time.
133//
134// We generate multi-part keys:
135//     <key,gen,hash>
136// where:
137//     key is in range [0..K-1]
138//     gen is a generation number for key
139//     hash is hash(key,gen)
140//
// The insertion code picks a random key, sets gen to be 1 + the last
// generation number inserted for that key, and sets hash to the low 8 bits
// of Hash(key,gen).
143//
144// At the beginning of a read, we snapshot the last inserted
145// generation number for each key.  We then iterate, including random
146// calls to Next() and Seek().  For every key we encounter, we
147// check that it is either expected given the initial snapshot or has
148// been concurrently added since the iterator started.
149class ConcurrentTest {
150 private:
151  static const uint32_t K = 4;
152
153  static uint64_t key(Key key) { return (key >> 40); }
154  static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
155  static uint64_t hash(Key key) { return key & 0xff; }
156
157  static uint64_t HashNumbers(uint64_t k, uint64_t g) {
158    uint64_t data[2] = { k, g };
159    return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
160  }
161
162  static Key MakeKey(uint64_t k, uint64_t g) {
163    assert(sizeof(Key) == sizeof(uint64_t));
164    assert(k <= K);  // We sometimes pass K to seek to the end of the skiplist
165    assert(g <= 0xffffffffu);
166    return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
167  }
168
169  static bool IsValidKey(Key k) {
170    return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
171  }
172
173  static Key RandomTarget(Random* rnd) {
174    switch (rnd->Next() % 10) {
175      case 0:
176        // Seek to beginning
177        return MakeKey(0, 0);
178      case 1:
179        // Seek to end
180        return MakeKey(K, 0);
181      default:
182        // Seek to middle
183        return MakeKey(rnd->Next() % K, 0);
184    }
185  }
186
187  // Per-key generation
188  struct State {
189    port::AtomicPointer generation[K];
190    void Set(int k, intptr_t v) {
191      generation[k].Release_Store(reinterpret_cast<void*>(v));
192    }
193    intptr_t Get(int k) {
194      return reinterpret_cast<intptr_t>(generation[k].Acquire_Load());
195    }
196
197    State() {
198      for (int k = 0; k < K; k++) {
199        Set(k, 0);
200      }
201    }
202  };
203
204  // Current state of the test
205  State current_;
206
207  Arena arena_;
208
209  // SkipList is not protected by mu_.  We just use a single writer
210  // thread to modify it.
211  SkipList<Key, Comparator> list_;
212
213 public:
214  ConcurrentTest() : list_(Comparator(), &arena_) { }
215
216  // REQUIRES: External synchronization
217  void WriteStep(Random* rnd) {
218    const uint32_t k = rnd->Next() % K;
219    const intptr_t g = current_.Get(k) + 1;
220    const Key key = MakeKey(k, g);
221    list_.Insert(key);
222    current_.Set(k, g);
223  }
224
225  void ReadStep(Random* rnd) {
226    // Remember the initial committed state of the skiplist.
227    State initial_state;
228    for (int k = 0; k < K; k++) {
229      initial_state.Set(k, current_.Get(k));
230    }
231
232    Key pos = RandomTarget(rnd);
233    SkipList<Key, Comparator>::Iterator iter(&list_);
234    iter.Seek(pos);
235    while (true) {
236      Key current;
237      if (!iter.Valid()) {
238        current = MakeKey(K, 0);
239      } else {
240        current = iter.key();
241        ASSERT_TRUE(IsValidKey(current)) << current;
242      }
243      ASSERT_LE(pos, current) << "should not go backwards";
244
245      // Verify that everything in [pos,current) was not present in
246      // initial_state.
247      while (pos < current) {
248        ASSERT_LT(key(pos), K) << pos;
249
250        // Note that generation 0 is never inserted, so it is ok if
251        // <*,0,*> is missing.
252        ASSERT_TRUE((gen(pos) == 0) ||
253                    (gen(pos) > initial_state.Get(key(pos)))
254                    ) << "key: " << key(pos)
255                      << "; gen: " << gen(pos)
256                      << "; initgen: "
257                      << initial_state.Get(key(pos));
258
259        // Advance to next key in the valid key space
260        if (key(pos) < key(current)) {
261          pos = MakeKey(key(pos) + 1, 0);
262        } else {
263          pos = MakeKey(key(pos), gen(pos) + 1);
264        }
265      }
266
267      if (!iter.Valid()) {
268        break;
269      }
270
271      if (rnd->Next() % 2) {
272        iter.Next();
273        pos = MakeKey(key(pos), gen(pos) + 1);
274      } else {
275        Key new_target = RandomTarget(rnd);
276        if (new_target > pos) {
277          pos = new_target;
278          iter.Seek(new_target);
279        }
280      }
281    }
282  }
283};
284const uint32_t ConcurrentTest::K;
285
286// Simple test that does single-threaded testing of the ConcurrentTest
287// scaffolding.
288TEST(SkipTest, ConcurrentWithoutThreads) {
289  ConcurrentTest test;
290  Random rnd(test::RandomSeed());
291  for (int i = 0; i < 10000; i++) {
292    test.ReadStep(&rnd);
293    test.WriteStep(&rnd);
294  }
295}
296
297class TestState {
298 public:
299  ConcurrentTest t_;
300  int seed_;
301  port::AtomicPointer quit_flag_;
302
303  enum ReaderState {
304    STARTING,
305    RUNNING,
306    DONE
307  };
308
309  explicit TestState(int s)
310      : seed_(s),
311        quit_flag_(NULL),
312        state_(STARTING),
313        state_cv_(&mu_) {}
314
315  void Wait(ReaderState s) {
316    mu_.Lock();
317    while (state_ != s) {
318      state_cv_.Wait();
319    }
320    mu_.Unlock();
321  }
322
323  void Change(ReaderState s) {
324    mu_.Lock();
325    state_ = s;
326    state_cv_.Signal();
327    mu_.Unlock();
328  }
329
330 private:
331  port::Mutex mu_;
332  ReaderState state_;
333  port::CondVar state_cv_;
334};
335
336static void ConcurrentReader(void* arg) {
337  TestState* state = reinterpret_cast<TestState*>(arg);
338  Random rnd(state->seed_);
339  int64_t reads = 0;
340  state->Change(TestState::RUNNING);
341  while (!state->quit_flag_.Acquire_Load()) {
342    state->t_.ReadStep(&rnd);
343    ++reads;
344  }
345  state->Change(TestState::DONE);
346}
347
348static void RunConcurrent(int run) {
349  const int seed = test::RandomSeed() + (run * 100);
350  Random rnd(seed);
351  const int N = 1000;
352  const int kSize = 1000;
353  for (int i = 0; i < N; i++) {
354    if ((i % 100) == 0) {
355      fprintf(stderr, "Run %d of %d\n", i, N);
356    }
357    TestState state(seed + 1);
358    Env::Default()->Schedule(ConcurrentReader, &state);
359    state.Wait(TestState::RUNNING);
360    for (int i = 0; i < kSize; i++) {
361      state.t_.WriteStep(&rnd);
362    }
363    state.quit_flag_.Release_Store(&state);  // Any non-NULL arg will do
364    state.Wait(TestState::DONE);
365  }
366}
367
368TEST(SkipTest, Concurrent1) { RunConcurrent(1); }
369TEST(SkipTest, Concurrent2) { RunConcurrent(2); }
370TEST(SkipTest, Concurrent3) { RunConcurrent(3); }
371TEST(SkipTest, Concurrent4) { RunConcurrent(4); }
372TEST(SkipTest, Concurrent5) { RunConcurrent(5); }
373
374}  // namespace leveldb
375
376int main(int argc, char** argv) {
377  return leveldb::test::RunAllTests();
378}
379