1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
5// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#define EIGEN_USE_THREADS
12#include "main.h"
13#include <Eigen/CXX11/ThreadPool>
14
// Reentrant random-number helper used by the stress-test threads, each of
// which owns its own seed `*s`.
//
// Visual Studio doesn't implement rand_r(), but its rand() is already thread
// safe, so plain rand() is used there and the seed is ignored.
//
// EIGEN_COMP_MSVC_STRICT is always defined by Eigen (to 0 when the compiler is
// not strict MSVC), so it must be tested with #if rather than #ifdef: #ifdef
// would be true on every compiler and route all platforms to rand().
//
// @param s  per-caller seed; updated in place by rand_r().
// @return   a pseudo-random value in [0, RAND_MAX].
int rand_reentrant(unsigned int* s) {
#if EIGEN_COMP_MSVC_STRICT
  EIGEN_UNUSED_VARIABLE(s);
  return rand();
#else
  return rand_r(s);
#endif
}
25
26static void test_basic_eventcount()
27{
28  MaxSizeVector<EventCount::Waiter> waiters(1);
29  waiters.resize(1);
30  EventCount ec(waiters);
31  EventCount::Waiter& w = waiters[0];
32  ec.Notify(false);
33  ec.Prewait(&w);
34  ec.Notify(true);
35  ec.CommitWait(&w);
36  ec.Prewait(&w);
37  ec.CancelWait(&w);
38}
39
40// Fake bounded counter-based queue.
41struct TestQueue {
42  std::atomic<int> val_;
43  static const int kQueueSize = 10;
44
45  TestQueue() : val_() {}
46
47  ~TestQueue() { VERIFY_IS_EQUAL(val_.load(), 0); }
48
49  bool Push() {
50    int val = val_.load(std::memory_order_relaxed);
51    for (;;) {
52      VERIFY_GE(val, 0);
53      VERIFY_LE(val, kQueueSize);
54      if (val == kQueueSize) return false;
55      if (val_.compare_exchange_weak(val, val + 1, std::memory_order_relaxed))
56        return true;
57    }
58  }
59
60  bool Pop() {
61    int val = val_.load(std::memory_order_relaxed);
62    for (;;) {
63      VERIFY_GE(val, 0);
64      VERIFY_LE(val, kQueueSize);
65      if (val == 0) return false;
66      if (val_.compare_exchange_weak(val, val - 1, std::memory_order_relaxed))
67        return true;
68    }
69  }
70
71  bool Empty() { return val_.load(std::memory_order_relaxed) == 0; }
72};
73
74const int TestQueue::kQueueSize;
75
76// A number of producers send messages to a set of consumers using a set of
77// fake queues. Ensure that it does not crash, consumers don't deadlock and
78// number of blocked and unblocked threads match.
79static void test_stress_eventcount()
80{
81  const int kThreads = std::thread::hardware_concurrency();
82  static const int kEvents = 1 << 16;
83  static const int kQueues = 10;
84
85  MaxSizeVector<EventCount::Waiter> waiters(kThreads);
86  waiters.resize(kThreads);
87  EventCount ec(waiters);
88  TestQueue queues[kQueues];
89
90  std::vector<std::unique_ptr<std::thread>> producers;
91  for (int i = 0; i < kThreads; i++) {
92    producers.emplace_back(new std::thread([&ec, &queues]() {
93      unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
94      for (int j = 0; j < kEvents; j++) {
95        unsigned idx = rand_reentrant(&rnd) % kQueues;
96        if (queues[idx].Push()) {
97          ec.Notify(false);
98          continue;
99        }
100        EIGEN_THREAD_YIELD();
101        j--;
102      }
103    }));
104  }
105
106  std::vector<std::unique_ptr<std::thread>> consumers;
107  for (int i = 0; i < kThreads; i++) {
108    consumers.emplace_back(new std::thread([&ec, &queues, &waiters, i]() {
109      EventCount::Waiter& w = waiters[i];
110      unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
111      for (int j = 0; j < kEvents; j++) {
112        unsigned idx = rand_reentrant(&rnd) % kQueues;
113        if (queues[idx].Pop()) continue;
114        j--;
115        ec.Prewait(&w);
116        bool empty = true;
117        for (int q = 0; q < kQueues; q++) {
118          if (!queues[q].Empty()) {
119            empty = false;
120            break;
121          }
122        }
123        if (!empty) {
124          ec.CancelWait(&w);
125          continue;
126        }
127        ec.CommitWait(&w);
128      }
129    }));
130  }
131
132  for (int i = 0; i < kThreads; i++) {
133    producers[i]->join();
134    consumers[i]->join();
135  }
136}
137
138void test_cxx11_eventcount()
139{
140  CALL_SUBTEST(test_basic_eventcount());
141  CALL_SUBTEST(test_stress_eventcount());
142}
143