1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Unittests for reraiser_thread.py."""
6
7import threading
8import unittest
9
10import reraiser_thread
11import watchdog_timer
12
13
class TestException(Exception):
  """Sentinel exception raised from thread targets in the tests below."""
16
17
class TestReraiserThread(unittest.TestCase):
  """Exercises a single reraiser_thread.ReraiserThread."""

  def testNominal(self):
    """Positional and keyword arguments are forwarded to the target."""
    captured = [None, None]

    def RecordArgs(a, b=None):
      captured[:] = [a, b]

    worker = reraiser_thread.ReraiserThread(RecordArgs, [1], {'b': 2})
    worker.start()
    worker.join()
    self.assertEqual([1, 2], captured)

  def testRaise(self):
    """An exception raised by the target surfaces via ReraiseIfException."""
    def Explode():
      raise TestException

    worker = reraiser_thread.ReraiserThread(Explode)
    worker.start()
    worker.join()
    self.assertRaises(TestException, worker.ReraiseIfException)
42
43
class TestReraiserThreadGroup(unittest.TestCase):
  """Tests for reraiser_thread.ReraiserThreadGroup."""

  def testInit(self):
    """Threads passed to the constructor are all started and joined."""
    ran = [False] * 5
    def f(i):
      ran[i] = True

    group = reraiser_thread.ReraiserThreadGroup(
      [reraiser_thread.ReraiserThread(f, args=[i]) for i in range(5)])
    group.StartAll()
    group.JoinAll()
    for v in ran:
      self.assertTrue(v)

  def testAdd(self):
    """Threads registered one at a time via Add() are all started and joined."""
    ran = [False] * 5
    def f(i):
      ran[i] = True

    group = reraiser_thread.ReraiserThreadGroup()
    # range (not xrange) matches testInit and also works on Python 3.
    for i in range(5):
      group.Add(reraiser_thread.ReraiserThread(f, args=[i]))
    group.StartAll()
    group.JoinAll()
    for v in ran:
      self.assertTrue(v)

  def testJoinRaise(self):
    """JoinAll() re-raises an exception raised inside a member thread."""
    def f():
      raise TestException

    group = reraiser_thread.ReraiserThreadGroup(
      [reraiser_thread.ReraiserThread(f) for _ in range(5)])
    group.StartAll()
    with self.assertRaises(TestException):
      group.JoinAll()

  def testJoinTimeout(self):
    """JoinAll() raises TimeoutError when a thread outlives the watchdog.

    g() blocks on an Event until the test releases it, so the 0.01s
    watchdog is expected to fire while g() is still running.
    """
    def f():
      pass
    event = threading.Event()
    def g():
      event.wait()

    group = reraiser_thread.ReraiserThreadGroup(
        [reraiser_thread.ReraiserThread(g),
         reraiser_thread.ReraiserThread(f)])
    group.StartAll()
    try:
      with self.assertRaises(reraiser_thread.TimeoutError):
        group.JoinAll(watchdog_timer.WatchdogTimer(0.01))
    finally:
      # Always release g(), even when the assertion fails. event.wait() has
      # no timeout, so otherwise that thread would stay blocked forever and
      # could hang the test process.
      event.set()
93
94
if __name__ == '__main__':
  # Discover and run every TestCase defined in this module.
  unittest.main(module='__main__')
97