1import os
2import unittest
3import random
4from test import test_support
5thread = test_support.import_module('thread')
6import time
7import sys
8import weakref
9
10from test import lock_tests
11
# NUMTASKS: how many worker threads each test starts.
# NUMTRIPS: how many times each BarrierTest worker passes the barrier.
NUMTASKS, NUMTRIPS = 10, 3


# Serializes verbose_print() so output from concurrent threads does not
# interleave.
_print_mutex = thread.allocate_lock()
17
18def verbose_print(arg):
19    """Helper function for printing out debugging output."""
20    if test_support.verbose:
21        with _print_mutex:
22            print arg
23
24
class BasicThreadTest(unittest.TestCase):
    """Shared fixture: the locks and counters used by the thread tests."""

    def setUp(self):
        # done_mutex starts out held; the test thread later blocks on it
        # until a worker thread releases it to signal completion.
        self.done_mutex = thread.allocate_lock()
        self.done_mutex.acquire()
        # Held by the subclasses while they update the counters below.
        self.running_mutex = thread.allocate_lock()
        # Held by the subclasses around calls into the random module.
        self.random_mutex = thread.allocate_lock()
        self.created = self.running = self.next_ident = 0
35
36
class ThreadRunningTests(BasicThreadTest):
    """Tests that start raw threads with thread.start_new_thread()."""

    def newtask(self):
        """Start one worker thread running self.task with a fresh ident.

        The counters are updated while running_mutex is held, and the
        worker takes the same lock before decrementing self.running, so a
        worker can never decrement the count before it was incremented.
        """
        with self.running_mutex:
            self.next_ident += 1
            verbose_print("creating task %s" % self.next_ident)
            thread.start_new_thread(self.task, (self.next_ident,))
            self.created += 1
            self.running += 1

    def task(self, ident):
        """Worker body: sleep for a short random time, then check out.

        Once all NUMTASKS workers have been created and none is still
        running, the finishing worker releases done_mutex to wake the
        waiting test thread.
        """
        with self.random_mutex:
            delay = random.random() / 10000.0
        verbose_print("task %s will run for %sus" % (ident, round(delay*1e6)))
        time.sleep(delay)
        verbose_print("task %s done" % ident)
        with self.running_mutex:
            self.running -= 1
            if self.created == NUMTASKS and self.running == 0:
                self.done_mutex.release()

    def test_starting_threads(self):
        # Basic test for thread creation.
        for i in range(NUMTASKS):
            self.newtask()
        verbose_print("waiting for tasks to complete...")
        self.done_mutex.acquire()
        verbose_print("all tasks done")

    def test_stack_size(self):
        # Various stack size tests.
        self.assertEqual(thread.stack_size(), 0, "initial stack size is not 0")

        thread.stack_size(0)
        self.assertEqual(thread.stack_size(), 0, "stack_size not reset to default")

        # Changing the stack size is only exercised on these platforms.
        if os.name not in ("nt", "os2", "posix"):
            return

        tss_supported = True
        try:
            thread.stack_size(4096)
        except ValueError:
            # 4096 bytes is expected to be rejected as an invalid size.
            verbose_print("caught expected ValueError setting "
                            "stack_size(4096)")
        except thread.error:
            tss_supported = False
            verbose_print("platform does not support changing thread stack "
                            "size")

        if not tss_supported:
            return

        try:
            fail_msg = "stack_size(%d) failed - should succeed"
            for tss in (262144, 0x100000, 0):
                thread.stack_size(tss)
                self.assertEqual(thread.stack_size(), tss, fail_msg % tss)
                verbose_print("successfully set stack_size(%d)" % tss)

            for tss in (262144, 0x100000):
                verbose_print("trying stack_size = (%d)" % tss)
                # Apply the size so the threads below really are created
                # with it; the loop above left the default (0) in place.
                thread.stack_size(tss)
                self.next_ident = 0
                self.created = 0
                for i in range(NUMTASKS):
                    self.newtask()

                verbose_print("waiting for all tasks to complete")
                self.done_mutex.acquire()
                verbose_print("all tasks done")
        finally:
            # Restore the default even on failure so later tests that
            # start threads are unaffected.
            thread.stack_size(0)

    def test__count(self):
        # Test the _count() function.
        orig = thread._count()
        mut = thread.allocate_lock()
        mut.acquire()
        started = []
        def task():
            started.append(None)
            mut.acquire()
            mut.release()
        thread.start_new_thread(task, ())
        # Wait until the new thread runs; it then blocks on mut.
        while not started:
            time.sleep(0.01)
        self.assertEqual(thread._count(), orig + 1)
        # Allow the task to finish.
        mut.release()
        # The only reliable way to be sure that the thread ended from the
        # interpreter's point of view is to wait for the function object to be
        # destroyed.
        done = []
        # wr must stay referenced: a weakref callback only fires while the
        # weakref object itself is still alive.
        wr = weakref.ref(task, lambda _: done.append(None))
        del task
        while not done:
            time.sleep(0.01)
        self.assertEqual(thread._count(), orig)

    def test_save_exception_state_on_error(self):
        # See issue #14474
        # The thread raises SyntaxError; while its traceback is written,
        # every write() first raises and handles its own ValueError. The
        # traceback must still reach stderr.
        def task():
            started.release()
            raise SyntaxError
        def mywrite(*args):
            # Installed as an instance attribute, so it receives only the
            # arguments of stderr.write() (no implicit self).
            try:
                raise ValueError
            except ValueError:
                pass
            real_write(*args)
        c = thread._count()
        started = thread.allocate_lock()
        with test_support.captured_output("stderr") as stderr:
            real_write = stderr.write
            stderr.write = mywrite
            started.acquire()
            thread.start_new_thread(task, ())
            # Block until task() has started running.
            started.acquire()
            while thread._count() > c:
                time.sleep(0.01)
        self.assertIn("Traceback", stderr.getvalue())
155
156
class Barrier:
    """Reusable barrier for a fixed number of threads, built from two locks.

    checkin_mutex guards the entrance and checkout_mutex the exit. The two
    are never open at the same time, so a thread that loops straight back
    into enter() blocks at the entrance until every thread of the previous
    round has left.
    """

    def __init__(self, num_threads):
        self.num_threads = num_threads
        # Threads currently inside the barrier: counted in on arrival and
        # out on departure.
        self.waiting = 0
        self.checkin_mutex  = thread.allocate_lock()
        self.checkout_mutex = thread.allocate_lock()
        # The exit starts closed: nobody may leave until all have arrived.
        self.checkout_mutex.acquire()

    def enter(self):
        """Block until num_threads threads have called enter()."""
        self.checkin_mutex.acquire()
        self.waiting = self.waiting + 1
        if self.waiting == self.num_threads:
            # Last arrival: it leaves at once, so count it out now and keep
            # the entrance closed while the others drain.
            self.waiting = self.num_threads - 1
            if self.waiting == 0:
                # Sole participant (num_threads == 1): nobody waits at the
                # exit, so reopen the entrance instead; releasing the exit
                # here would leave checkin_mutex held and deadlock the
                # next enter().
                self.checkin_mutex.release()
            else:
                self.checkout_mutex.release()
            return
        self.checkin_mutex.release()

        # Wait for the exit to open, then pass it on to the next waiter;
        # the last one out closes the exit and reopens the entrance.
        self.checkout_mutex.acquire()
        self.waiting = self.waiting - 1
        if self.waiting == 0:
            self.checkin_mutex.release()
            return
        self.checkout_mutex.release()
180
181
class BarrierTest(BasicThreadTest):
    """Drive NUMTASKS threads through a shared Barrier NUMTRIPS times."""

    def test_barrier(self):
        self.bar = Barrier(NUMTASKS)
        self.running = NUMTASKS
        for ident in range(NUMTASKS):
            thread.start_new_thread(self.task2, (ident,))
        verbose_print("waiting for tasks to end")
        # Released by the last worker to finish task2().
        self.done_mutex.acquire()
        verbose_print("tasks done")

    def task2(self, ident):
        for trip in range(NUMTRIPS):
            if ident != 0:
                with self.random_mutex:
                    delay = random.random() / 10000.0
            else:
                # Task 0 never sleeps, giving it a good chance to enter the
                # next barrier before the others are all out of the
                # current one.
                delay = 0
            verbose_print("task %s will run for %sus" %
                          (ident, round(delay * 1e6)))
            time.sleep(delay)
            verbose_print("task %s entering %s" % (ident, trip))
            self.bar.enter()
            verbose_print("task %s leaving barrier" % ident)
        # Decide under running_mutex, but release done_mutex only after
        # running_mutex has been dropped: once done is released the main
        # thread can exit and tear down module globals, after which a
        # pending mutex.release() could fail.
        with self.running_mutex:
            self.running -= 1
            finished = self.running == 0
        if finished:
            self.done_mutex.release()
217
218
class LockTests(lock_tests.LockTests):
    """Reuse the lock_tests.LockTests suite with thread.allocate_lock locks."""
    # NOTE(review): presumably lock_tests.LockTests builds each lock under
    # test by calling self.locktype(); confirm in test/lock_tests.py.
    locktype = thread.allocate_lock
221
222
class TestForkInThread(unittest.TestCase):
    """Check that os.fork() can be called from a non-main thread."""

    def setUp(self):
        # The child reports success by writing "OK" into this pipe.
        self.read_fd, self.write_fd = os.pipe()

    @unittest.skipIf(sys.platform.startswith('win'),
                     "This test is only appropriate for POSIX-like systems.")
    @test_support.reap_threads
    def test_forkinthread(self):
        # Filled in by the forking thread in the parent, after it is done
        # with write_fd: the child's pid.
        child_pids = []

        def thread1():
            try:
                pid = os.fork() # fork in a thread
            except (OSError, RuntimeError):
                # No child exists. Close the parent's write end so the
                # os.read() below returns '' and the test fails, instead of
                # blocking forever on a pipe nobody will write to.
                os.close(self.write_fd)
                return

            if pid == 0: # child
                try:
                    os.close(self.read_fd)
                    os.write(self.write_fd, "OK")
                finally:
                    # Leave the forked child at once, without running the
                    # inherited interpreter's exit handling.
                    os._exit(0)
            else: # parent
                try:
                    os.close(self.write_fd)
                finally:
                    child_pids.append(pid)

        thread.start_new_thread(thread1, ())
        self.assertEqual(os.read(self.read_fd, 2), "OK",
                         "Unable to fork() in thread")
        # The child wrote "OK", so the fork succeeded and the parent thread
        # will record the pid. Waiting for it also guarantees the thread is
        # done with write_fd before tearDown() closes it. Then reap the
        # child so it does not linger as a zombie.
        while not child_pids:
            time.sleep(0.01)
        os.waitpid(child_pids[0], 0)

    def tearDown(self):
        # Either end of the pipe may already have been closed by the test.
        for fd in (self.read_fd, self.write_fd):
            try:
                os.close(fd)
            except OSError:
                pass
258
259
def test_main():
    """Run every test case class defined in this module."""
    test_classes = (ThreadRunningTests, BarrierTest, LockTests,
                    TestForkInThread)
    test_support.run_unittest(*test_classes)

if __name__ == "__main__":
    test_main()
266