1/*
2 * Test program that illustrates how to annotate a smart pointer
3 * implementation.  In a multithreaded program the following is relevant when
4 * working with smart pointers:
5 * - whether or not the objects pointed at are shared over threads.
6 * - whether or not the methods of the objects pointed at are thread-safe.
7 * - whether or not the smart pointer objects are shared over threads.
8 * - whether or not the smart pointer object itself is thread-safe.
9 *
 * Most smart pointer implementations are not thread-safe
 * (e.g. boost::shared_ptr<>, tr1::shared_ptr<> and the smart_ptr<>
 * implementation below). This means that it is not safe to modify a smart
 * pointer object that is shared over threads without proper synchronization.
14 *
 * Even for non-thread-safe smart pointers it is possible to have different
 * threads access the same object via smart pointers without triggering data
 * races on the smart pointer objects, provided that each thread uses its own
 * smart pointer object.
18 *
19 * A smart pointer implementation guarantees that the destructor of the object
20 * pointed at is invoked after the last smart pointer that points to that
21 * object has been destroyed or reset. Data race detection tools cannot detect
22 * this ordering without explicit annotation for smart pointers that track
23 * references without invoking synchronization operations recognized by data
24 * race detection tools.
25 */
26
27
#include <algorithm>   // std::max()
#include <cassert>     // assert()
#include <climits>     // PTHREAD_STACK_MIN
#include <iostream>    // std::cerr
#include <stdlib.h>    // atoi()
#include <vector>
#ifdef _WIN32
#include <process.h>   // _beginthreadex()
#include <windows.h>   // CRITICAL_SECTION
#else
#include <pthread.h>   // pthread_mutex_t
#include <time.h>      // nanosleep()
#endif
#include "unified_annotations.h"
40
41
// Whether smart_ptr<> emits happens-before / happens-after annotations.
// main() sets it from its third command-line argument (enabled by default).
static bool s_enable_annotations = false;
43
44
45#ifdef _WIN32
46
47class AtomicInt32
48{
49public:
50  AtomicInt32(const int value = 0) : m_value(value) { }
51  ~AtomicInt32() { }
52  LONG operator++() { return InterlockedIncrement(&m_value); }
53  LONG operator--() { return InterlockedDecrement(&m_value); }
54
55private:
56  volatile LONG m_value;
57};
58
59class Mutex
60{
61public:
62  Mutex() : m_mutex()
63  { InitializeCriticalSection(&m_mutex); }
64  ~Mutex()
65  { DeleteCriticalSection(&m_mutex); }
66  void Lock()
67  { EnterCriticalSection(&m_mutex); }
68  void Unlock()
69  { LeaveCriticalSection(&m_mutex); }
70
71private:
72  CRITICAL_SECTION m_mutex;
73};
74
75class Thread
76{
77public:
78  Thread() : m_thread(INVALID_HANDLE_VALUE) { }
79  ~Thread() { }
80  void Create(void* (*pf)(void*), void* arg)
81  {
82    WrapperArgs* wrapper_arg_p = new WrapperArgs(pf, arg);
83    m_thread = reinterpret_cast<HANDLE>(_beginthreadex(NULL, 0, wrapper,
84						       wrapper_arg_p, 0, NULL));
85  }
86  void Join()
87  { WaitForSingleObject(m_thread, INFINITE); }
88
89private:
90  struct WrapperArgs
91  {
92    WrapperArgs(void* (*pf)(void*), void* arg) : m_pf(pf), m_arg(arg) { }
93
94    void* (*m_pf)(void*);
95    void* m_arg;
96  };
97  static unsigned int __stdcall wrapper(void* arg)
98  {
99    WrapperArgs* wrapper_arg_p = reinterpret_cast<WrapperArgs*>(arg);
100    WrapperArgs wa = *wrapper_arg_p;
101    delete wrapper_arg_p;
102    return reinterpret_cast<unsigned>((wa.m_pf)(wa.m_arg));
103  }
104  HANDLE m_thread;
105};
106
107#else // _WIN32
108
// Integer counter updated with GCC's __sync builtins. The prefix operators
// atomically update the value and return the new value.
class AtomicInt32
{
public:
  AtomicInt32(const int value = 0) : m_value(value) { }
  ~AtomicInt32() { }

  int operator++()
  {
    const int incremented = __sync_add_and_fetch(&m_value, 1);
    return incremented;
  }

  int operator--()
  {
    const int decremented = __sync_sub_and_fetch(&m_value, 1);
    return decremented;
  }

private:
  volatile int m_value;
};
119
// Mutex built on a default-initialized pthread_mutex_t. Lock()/Unlock() do
// not report errors. Copying is disabled: copying a pthread_mutex_t is
// undefined, and both destructors would call pthread_mutex_destroy().
class Mutex
{
public:
  Mutex() : m_mutex()
  { pthread_mutex_init(&m_mutex, NULL); }
  ~Mutex()
  { pthread_mutex_destroy(&m_mutex); }
  void Lock()
  { pthread_mutex_lock(&m_mutex); }
  void Unlock()
  { pthread_mutex_unlock(&m_mutex); }

private:
  // Declared but not defined (C++03 idiom for a non-copyable class).
  Mutex(const Mutex&);
  Mutex& operator=(const Mutex&);

  pthread_mutex_t m_mutex;
};
135
// Minimal wrapper around a POSIX thread. Create() starts a thread that runs
// pf(arg); Join() waits for it.
class Thread
{
public:
  Thread() : m_tid(), m_started(false) { }
  ~Thread() { }
  void Create(void* (*pf)(void*), void* arg)
  {
    pthread_attr_t attr;
    pthread_attr_init(&attr);
#ifdef PTHREAD_STACK_MIN
    // Request a small stack (PTHREAD_STACK_MIN + 4096). The result is not
    // checked: if the size is rejected, the attribute keeps its default.
    pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096);
#endif
    const int err = pthread_create(&m_tid, &attr, pf, arg);
    pthread_attr_destroy(&attr);
    m_started = (err == 0);
    if (!m_started)
      std::cerr << "pthread_create() failed with error " << err << ".\n";
  }
  // Waits for the thread started by Create(). A no-op if no thread is running
  // (Create() not called, Create() failed, or already joined); calling
  // pthread_join() on such an id would be undefined behavior.
  void Join()
  {
    if (!m_started)
      return;
    pthread_join(m_tid, NULL);
    m_started = false;
  }
private:
  pthread_t m_tid;
  bool      m_started;
};
154
155#endif // !defined(_WIN32)
156
157
158template<class T>
159class smart_ptr
160{
161public:
162  typedef AtomicInt32 counter_t;
163
164  template <typename Q> friend class smart_ptr;
165
166  explicit smart_ptr()
167    : m_ptr(NULL), m_count_ptr(NULL)
168  { }
169
170  explicit smart_ptr(T* const pT)
171    : m_ptr(NULL), m_count_ptr(NULL)
172  {
173    set(pT, pT ? new counter_t(0) : NULL);
174  }
175
176  template <typename Q>
177  explicit smart_ptr(Q* const q)
178    : m_ptr(NULL), m_count_ptr(NULL)
179  {
180    set(q, q ? new counter_t(0) : NULL);
181  }
182
183  ~smart_ptr()
184  {
185    set(NULL, NULL);
186  }
187
188  smart_ptr(const smart_ptr<T>& sp)
189    : m_ptr(NULL), m_count_ptr(NULL)
190  {
191    set(sp.m_ptr, sp.m_count_ptr);
192  }
193
194  template <typename Q>
195  smart_ptr(const smart_ptr<Q>& sp)
196    : m_ptr(NULL), m_count_ptr(NULL)
197  {
198    set(sp.m_ptr, sp.m_count_ptr);
199  }
200
201  smart_ptr& operator=(const smart_ptr<T>& sp)
202  {
203    set(sp.m_ptr, sp.m_count_ptr);
204    return *this;
205  }
206
207  smart_ptr& operator=(T* const p)
208  {
209    set(p, p ? new counter_t(0) : NULL);
210    return *this;
211  }
212
213  template <typename Q>
214  smart_ptr& operator=(Q* const q)
215  {
216    set(q, q ? new counter_t(0) : NULL);
217    return *this;
218  }
219
220  T* operator->() const
221  {
222    assert(m_ptr);
223    return m_ptr;
224  }
225
226  T& operator*() const
227  {
228    assert(m_ptr);
229    return *m_ptr;
230  }
231
232private:
233  void set(T* const pT, counter_t* const count_ptr)
234  {
235    if (m_ptr != pT)
236    {
237      if (m_count_ptr)
238      {
239	if (s_enable_annotations)
240	  U_ANNOTATE_HAPPENS_BEFORE(m_count_ptr);
241	if (--(*m_count_ptr) == 0)
242	{
243	  if (s_enable_annotations)
244	    U_ANNOTATE_HAPPENS_AFTER(m_count_ptr);
245	  delete m_ptr;
246	  m_ptr = NULL;
247	  delete m_count_ptr;
248	  m_count_ptr = NULL;
249	}
250      }
251      m_ptr = pT;
252      m_count_ptr = count_ptr;
253      if (count_ptr)
254	++(*m_count_ptr);
255    }
256  }
257
258  T*         m_ptr;
259  counter_t* m_count_ptr;
260};
261
262class counter
263{
264public:
265  counter()
266    : m_mutex(), m_count()
267  { }
268  ~counter()
269  {
270    // Data race detection tools that do not recognize the
271    // ANNOTATE_HAPPENS_BEFORE() / ANNOTATE_HAPPENS_AFTER() annotations in the
272    // smart_ptr<> implementation will report that the assignment below
273    // triggers a data race.
274    m_count = -1;
275  }
276  int get() const
277  {
278    int result;
279    m_mutex.Lock();
280    result = m_count;
281    m_mutex.Unlock();
282    return result;
283  }
284  int post_increment()
285  {
286    int result;
287    m_mutex.Lock();
288    result = m_count++;
289    m_mutex.Unlock();
290    return result;
291  }
292
293private:
294  mutable Mutex m_mutex;
295  int           m_count;
296};
297
298static void* thread_func(void* arg)
299{
300  smart_ptr<counter>* pp = reinterpret_cast<smart_ptr<counter>*>(arg);
301  (*pp)->post_increment();
302  *pp = NULL;
303  delete pp;
304  return NULL;
305}
306
307int main(int argc, char** argv)
308{
309  const int nthreads = std::max(argc > 1 ? atoi(argv[1]) : 1, 1);
310  const int iterations = std::max(argc > 2 ? atoi(argv[2]) : 1, 1);
311  s_enable_annotations = argc > 3 ? !!atoi(argv[3]) : true;
312
313  for (int j = 0; j < iterations; ++j)
314  {
315    std::vector<Thread> T(nthreads);
316    smart_ptr<counter> p(new counter);
317    p->post_increment();
318    for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
319      q->Create(thread_func, new smart_ptr<counter>(p));
320    {
321      // Avoid that counter.m_mutex introduces a false ordering on the
322      // counter.m_count accesses.
323      const timespec delay = { 0, 100 * 1000 * 1000 };
324      nanosleep(&delay, 0);
325    }
326    p = NULL;
327    for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
328      q->Join();
329  }
330  std::cerr << "Done.\n";
331  return 0;
332}
333