1/**
2 * @file  rwlock_test.c
3 *
 * @brief Multithreaded test program that exercises read-write lock access
 *        patterns (shared reads and exclusive writes) without introducing
 *        any race conditions.
6 */
7
8
9#define _GNU_SOURCE 1
10
#include <assert.h>
#include <errno.h>   /* errno, ERANGE */
#include <limits.h>  /* INT_MAX, PTHREAD_STACK_MIN */
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>  /* calloc(), free(), strtol() */
#include <string.h>  /* strerror() */
#include <unistd.h>  /* getopt(), optarg, optopt */
18
/* Run parameters; main() may override both from the command line. */
static int s_num_iterations = 1000;
static int s_num_threads = 10;

/* Reader-writer lock and the counter it guards. */
static pthread_rwlock_t s_rwlock;
static int s_counter;          /* guarded by s_rwlock */

/* Mutex and the accumulated per-thread sums it guards. */
static pthread_mutex_t s_mutex;
static long long s_grand_sum;  /* guarded by s_mutex */
25
26static void* thread_func(void* arg)
27{
28  int i, r;
29  int sum1 = 0, sum2 = 0;
30
31  for (i = s_num_iterations; i > 0; i--)
32  {
33    r = pthread_rwlock_rdlock(&s_rwlock);
34    assert(! r);
35    sum1 += s_counter;
36    r = pthread_rwlock_unlock(&s_rwlock);
37    assert(! r);
38    r = pthread_rwlock_wrlock(&s_rwlock);
39    assert(! r);
40    sum2 += s_counter++;
41    r = pthread_rwlock_unlock(&s_rwlock);
42    assert(! r);
43  }
44
45  pthread_mutex_lock(&s_mutex);
46  s_grand_sum += sum2;
47  pthread_mutex_unlock(&s_mutex);
48
49  return 0;
50}
51
52int main(int argc, char** argv)
53{
54  pthread_attr_t attr;
55  pthread_t* tid;
56  int threads_created;
57  int optchar;
58  int err;
59  int i;
60  int expected_counter;
61  long long expected_grand_sum;
62
63  while ((optchar = getopt(argc, argv, "i:t:")) != EOF)
64  {
65    switch (optchar)
66    {
67    case 'i':
68      s_num_iterations = atoi(optarg);
69      break;
70    case 't':
71      s_num_threads = atoi(optarg);
72      break;
73    default:
74      fprintf(stderr, "Error: unknown option '%c'.\n", optchar);
75      return 1;
76    }
77  }
78
79  pthread_mutex_init(&s_mutex, NULL);
80  pthread_rwlock_init(&s_rwlock, NULL);
81
82  pthread_attr_init(&attr);
83  err = pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096);
84  assert(err == 0);
85
86  tid = calloc(s_num_threads, sizeof(*tid));
87  threads_created = 0;
88  for (i = 0; i < s_num_threads; i++)
89  {
90    err = pthread_create(&tid[i], &attr, thread_func, 0);
91    if (err)
92      printf("failed to create thread %d: %s\n", i, strerror(err));
93    else
94      threads_created++;
95  }
96
97  pthread_attr_destroy(&attr);
98
99  for (i = 0; i < s_num_threads; i++)
100  {
101    if (tid[i])
102      pthread_join(tid[i], 0);
103  }
104  free(tid);
105
106  expected_counter = threads_created * s_num_iterations;
107  fprintf(stderr, "s_counter - expected_counter = %d\n",
108          s_counter - expected_counter);
109  expected_grand_sum = 1ULL * expected_counter * (expected_counter - 1) / 2;
110  fprintf(stderr, "s_grand_sum - expected_grand_sum = %lld\n",
111          s_grand_sum - expected_grand_sum);
112  fprintf(stderr, "Finished.\n");
113
114  return 0;
115}
116