1/*-------------------------------------------------------------------------
2 * drawElements C++ Base Library
3 * -----------------------------
4 *
5 * Copyright 2015 The Android Open Source Project
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *      http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 *//*!
20 * \file
21 * \brief Cross-thread barrier.
22 *//*--------------------------------------------------------------------*/
23
24#include "deSpinBarrier.hpp"
25#include "deThread.hpp"
26#include "deRandom.hpp"
27#include "deInt32.h"
28
29#include <vector>
30
31namespace de
32{
33
34SpinBarrier::SpinBarrier (deInt32 numThreads)
35	: m_numCores	(deGetNumAvailableLogicalCores())
36	, m_numThreads	(numThreads)
37	, m_numEntered	(0)
38	, m_numLeaving	(0)
39	, m_numRemoved	(0)
40{
41	DE_ASSERT(numThreads > 0);
42}
43
44SpinBarrier::~SpinBarrier (void)
45{
46	DE_ASSERT(m_numEntered == 0 && m_numLeaving == 0);
47}
48
49void SpinBarrier::reset (deUint32 numThreads)
50{
51	// If last threads were removed, m_numEntered > 0 && m_numRemoved > 0
52	DE_ASSERT(m_numLeaving == 0);
53	DE_ASSERT(numThreads > 0);
54	m_numThreads = numThreads;
55	m_numEntered = 0;
56	m_numLeaving = 0;
57	m_numRemoved = 0;
58}
59
60inline SpinBarrier::WaitMode getWaitMode (SpinBarrier::WaitMode requested, deUint32 numCores, deInt32 numThreads)
61{
62	if (requested == SpinBarrier::WAIT_MODE_AUTO)
63		return ((deUint32)numThreads <= numCores) ? SpinBarrier::WAIT_MODE_BUSY : SpinBarrier::WAIT_MODE_YIELD;
64	else
65		return requested;
66}
67
68inline void wait (SpinBarrier::WaitMode mode)
69{
70	DE_ASSERT(mode == SpinBarrier::WAIT_MODE_YIELD || mode == SpinBarrier::WAIT_MODE_BUSY);
71
72	if (mode == SpinBarrier::WAIT_MODE_YIELD)
73		deYield();
74}
75
76void SpinBarrier::sync (WaitMode requestedMode)
77{
78	const deInt32	cachedNumThreads	= m_numThreads;
79	const WaitMode	waitMode			= getWaitMode(requestedMode, m_numCores, cachedNumThreads);
80
81	deMemoryReadWriteFence();
82
83	// m_numEntered must not be touched until all threads have had
84	// a chance to observe it being 0.
85	if (m_numLeaving > 0)
86	{
87		for (;;)
88		{
89			if (m_numLeaving == 0)
90				break;
91
92			wait(waitMode);
93		}
94	}
95
96	// If m_numRemoved > 0, m_numThreads will decrease. If m_numThreads is decreased
97	// just after atomicOp and before comparison, the branch could be taken by multiple
98	// threads. Since m_numThreads only changes if all threads are inside the spinbarrier,
99	// cached value at snapshotted at the beginning of the function will be equal for
100	// all threads.
101	if (deAtomicIncrement32(&m_numEntered) == cachedNumThreads)
102	{
103		// Release all waiting threads. Since this thread has not been removed, m_numLeaving will
104		// be >= 1 until m_numLeaving is decremented at the end of this function.
105		m_numThreads -= m_numRemoved;
106		m_numLeaving  = m_numThreads;
107		m_numRemoved  = 0;
108
109		deMemoryReadWriteFence();
110		m_numEntered  = 0;
111	}
112	else
113	{
114		for (;;)
115		{
116			if (m_numEntered == 0)
117				break;
118
119			wait(waitMode);
120		}
121	}
122
123	deAtomicDecrement32(&m_numLeaving);
124	deMemoryReadWriteFence();
125}
126
127void SpinBarrier::removeThread (WaitMode requestedMode)
128{
129	const deInt32	cachedNumThreads	= m_numThreads;
130	const WaitMode	waitMode			= getWaitMode(requestedMode, m_numCores, cachedNumThreads);
131
132	// Wait for other threads exiting previous barrier
133	if (m_numLeaving > 0)
134	{
135		for (;;)
136		{
137			if (m_numLeaving == 0)
138				break;
139
140			wait(waitMode);
141		}
142	}
143
144	// Ask for last thread entering barrier to adjust thread count
145	deAtomicIncrement32(&m_numRemoved);
146
147	// See sync() - use cached value
148	if (deAtomicIncrement32(&m_numEntered) == cachedNumThreads)
149	{
150		// Release all waiting threads.
151		m_numThreads -= m_numRemoved;
152		m_numLeaving  = m_numThreads;
153		m_numRemoved  = 0;
154
155		deMemoryReadWriteFence();
156		m_numEntered  = 0;
157	}
158}
159
160namespace
161{
162
163void singleThreadTest (SpinBarrier::WaitMode mode)
164{
165	SpinBarrier barrier(1);
166
167	barrier.sync(mode);
168	barrier.sync(mode);
169	barrier.sync(mode);
170}
171
172class TestThread : public de::Thread
173{
174public:
175	TestThread (SpinBarrier& barrier, volatile deInt32* sharedVar, int numThreads, int threadNdx)
176		: m_barrier		(barrier)
177		, m_sharedVar	(sharedVar)
178		, m_numThreads	(numThreads)
179		, m_threadNdx	(threadNdx)
180		, m_busyOk		((deUint32)m_numThreads <= deGetNumAvailableLogicalCores())
181	{
182	}
183
184	void run (void)
185	{
186		const int	numIters	= 10000;
187		de::Random	rnd			(deInt32Hash(m_numThreads) ^ deInt32Hash(m_threadNdx));
188
189		for (int iterNdx = 0; iterNdx < numIters; iterNdx++)
190		{
191			// Phase 1: count up
192			deAtomicIncrement32(m_sharedVar);
193
194			// Verify
195			m_barrier.sync(getWaitMode(rnd));
196
197			DE_TEST_ASSERT(*m_sharedVar == m_numThreads);
198
199			m_barrier.sync(getWaitMode(rnd));
200
201			// Phase 2: count down
202			deAtomicDecrement32(m_sharedVar);
203
204			// Verify
205			m_barrier.sync(getWaitMode(rnd));
206
207			DE_TEST_ASSERT(*m_sharedVar == 0);
208
209			m_barrier.sync(getWaitMode(rnd));
210		}
211	}
212
213private:
214	SpinBarrier&			m_barrier;
215	volatile deInt32* const	m_sharedVar;
216	const int				m_numThreads;
217	const int				m_threadNdx;
218	const bool				m_busyOk;
219
220	SpinBarrier::WaitMode getWaitMode (de::Random& rnd)
221	{
222		static const SpinBarrier::WaitMode	s_allModes[]	=
223		{
224			SpinBarrier::WAIT_MODE_YIELD,
225			SpinBarrier::WAIT_MODE_AUTO,
226			SpinBarrier::WAIT_MODE_BUSY,
227		};
228		const int							numModes		= DE_LENGTH_OF_ARRAY(s_allModes) - (m_busyOk ? 0 : 1);
229
230		return rnd.choose<SpinBarrier::WaitMode>(DE_ARRAY_BEGIN(s_allModes), DE_ARRAY_BEGIN(s_allModes) + numModes);
231	}
232};
233
234void multiThreadTest (int numThreads)
235{
236	SpinBarrier					barrier		(numThreads);
237	volatile deInt32			sharedVar	= 0;
238	std::vector<TestThread*>	threads		(numThreads, static_cast<TestThread*>(DE_NULL));
239
240	for (int ndx = 0; ndx < numThreads; ndx++)
241	{
242		threads[ndx] = new TestThread(barrier, &sharedVar, numThreads, ndx);
243		DE_TEST_ASSERT(threads[ndx]);
244		threads[ndx]->start();
245	}
246
247	for (int ndx = 0; ndx < numThreads; ndx++)
248	{
249		threads[ndx]->join();
250		delete threads[ndx];
251	}
252
253	DE_TEST_ASSERT(sharedVar == 0);
254}
255
256void singleThreadRemoveTest (SpinBarrier::WaitMode mode)
257{
258	SpinBarrier barrier(3);
259
260	barrier.removeThread(mode);
261	barrier.removeThread(mode);
262	barrier.sync(mode);
263	barrier.removeThread(mode);
264
265	barrier.reset(1);
266	barrier.sync(mode);
267
268	barrier.reset(2);
269	barrier.removeThread(mode);
270	barrier.sync(mode);
271}
272
273class TestExitThread : public de::Thread
274{
275public:
276	TestExitThread (SpinBarrier& barrier, int numThreads, int threadNdx, SpinBarrier::WaitMode waitMode)
277		: m_barrier		(barrier)
278		, m_numThreads	(numThreads)
279		, m_threadNdx	(threadNdx)
280		, m_waitMode	(waitMode)
281	{
282	}
283
284	void run (void)
285	{
286		const int	numIters	= 10000;
287		de::Random	rnd			(deInt32Hash(m_numThreads) ^ deInt32Hash(m_threadNdx) ^ deInt32Hash((deInt32)m_waitMode));
288		const int	invExitProb	= 1000;
289
290		for (int iterNdx = 0; iterNdx < numIters; iterNdx++)
291		{
292			if (rnd.getInt(0, invExitProb) == 0)
293			{
294				m_barrier.removeThread(m_waitMode);
295				break;
296			}
297			else
298				m_barrier.sync(m_waitMode);
299		}
300	}
301
302private:
303	SpinBarrier&				m_barrier;
304	const int					m_numThreads;
305	const int					m_threadNdx;
306	const SpinBarrier::WaitMode	m_waitMode;
307};
308
309void multiThreadRemoveTest (int numThreads, SpinBarrier::WaitMode waitMode)
310{
311	SpinBarrier						barrier		(numThreads);
312	std::vector<TestExitThread*>	threads		(numThreads, static_cast<TestExitThread*>(DE_NULL));
313
314	for (int ndx = 0; ndx < numThreads; ndx++)
315	{
316		threads[ndx] = new TestExitThread(barrier, numThreads, ndx, waitMode);
317		DE_TEST_ASSERT(threads[ndx]);
318		threads[ndx]->start();
319	}
320
321	for (int ndx = 0; ndx < numThreads; ndx++)
322	{
323		threads[ndx]->join();
324		delete threads[ndx];
325	}
326}
327
328} // anonymous
329
330void SpinBarrier_selfTest (void)
331{
332	singleThreadTest(SpinBarrier::WAIT_MODE_YIELD);
333	singleThreadTest(SpinBarrier::WAIT_MODE_BUSY);
334	singleThreadTest(SpinBarrier::WAIT_MODE_AUTO);
335	multiThreadTest(1);
336	multiThreadTest(2);
337	multiThreadTest(4);
338	multiThreadTest(8);
339	multiThreadTest(16);
340
341	singleThreadRemoveTest(SpinBarrier::WAIT_MODE_YIELD);
342	singleThreadRemoveTest(SpinBarrier::WAIT_MODE_BUSY);
343	singleThreadRemoveTest(SpinBarrier::WAIT_MODE_AUTO);
344	multiThreadRemoveTest(1, SpinBarrier::WAIT_MODE_BUSY);
345	multiThreadRemoveTest(2, SpinBarrier::WAIT_MODE_AUTO);
346	multiThreadRemoveTest(4, SpinBarrier::WAIT_MODE_AUTO);
347	multiThreadRemoveTest(8, SpinBarrier::WAIT_MODE_AUTO);
348	multiThreadRemoveTest(16, SpinBarrier::WAIT_MODE_AUTO);
349	multiThreadRemoveTest(1, SpinBarrier::WAIT_MODE_YIELD);
350	multiThreadRemoveTest(2, SpinBarrier::WAIT_MODE_YIELD);
351	multiThreadRemoveTest(4, SpinBarrier::WAIT_MODE_YIELD);
352	multiThreadRemoveTest(8, SpinBarrier::WAIT_MODE_YIELD);
353	multiThreadRemoveTest(16, SpinBarrier::WAIT_MODE_YIELD);
354}
355
356} // de
357