1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Unittests for reraiser_thread.py."""
6
7import threading
8import unittest
9
10from devil.utils import reraiser_thread
11from devil.utils import watchdog_timer
12
13
class TestException(Exception):
  """Sentinel error raised by worker functions inside these tests."""
16
17
class TestReraiserThread(unittest.TestCase):
  """Unit tests covering a single reraiser_thread.ReraiserThread."""

  def testNominal(self):
    # Filled in by the worker: [positional argument, keyword argument].
    seen = [None] * 2

    def record(a, b=None):
      seen[:] = [a, b]

    worker = reraiser_thread.ReraiserThread(record, [1], {'b': 2})
    worker.start()
    worker.join()
    self.assertEqual(1, seen[0])
    self.assertEqual(2, seen[1])

  def testRaise(self):
    def explode():
      raise TestException

    worker = reraiser_thread.ReraiserThread(explode)
    worker.start()
    worker.join()
    # The exception captured on the worker must resurface on this thread.
    self.assertRaises(TestException, worker.ReraiseIfException)
43
44
class TestReraiserThreadGroup(unittest.TestCase):
  """Tests for reraiser_thread.ReraiserThreadGroup."""

  # Number of worker threads spawned by the fan-out tests below.
  _NUM_THREADS = 5

  def testInit(self):
    """Threads passed to the constructor are all started and joined."""
    ran = [False] * self._NUM_THREADS

    def f(i):
      ran[i] = True

    group = reraiser_thread.ReraiserThreadGroup(
        [reraiser_thread.ReraiserThread(f, args=[i])
         for i in range(self._NUM_THREADS)])
    group.StartAll()
    group.JoinAll()
    # Compare the whole list so a failure shows exactly which workers
    # did not run.
    self.assertEqual([True] * self._NUM_THREADS, ran)

  def testAdd(self):
    """Threads registered via Add() are all started and joined."""
    ran = [False] * self._NUM_THREADS

    def f(i):
      ran[i] = True

    group = reraiser_thread.ReraiserThreadGroup()
    # range (not the Python-2-only xrange) keeps this runnable on Python 3.
    for i in range(self._NUM_THREADS):
      group.Add(reraiser_thread.ReraiserThread(f, args=[i]))
    group.StartAll()
    group.JoinAll()
    self.assertEqual([True] * self._NUM_THREADS, ran)

  def testJoinRaise(self):
    """JoinAll() re-raises an exception thrown inside a worker thread."""
    def f():
      raise TestException

    group = reraiser_thread.ReraiserThreadGroup(
        [reraiser_thread.ReraiserThread(f) for _ in range(self._NUM_THREADS)])
    group.StartAll()
    with self.assertRaises(TestException):
      group.JoinAll()

  def testJoinTimeout(self):
    """JoinAll() raises TimeoutError when a worker outlives the watchdog."""
    def f():
      pass

    event = threading.Event()

    def g():
      # Blocks until the test releases it, guaranteeing the timeout fires.
      event.wait()

    group = reraiser_thread.ReraiserThreadGroup(
        [reraiser_thread.ReraiserThread(g),
         reraiser_thread.ReraiserThread(f)])
    group.StartAll()
    try:
      with self.assertRaises(reraiser_thread.TimeoutError):
        group.JoinAll(watchdog_timer.WatchdogTimer(0.01))
    finally:
      # Always release g(), even if the assertion above fails, so its
      # thread does not stay blocked after the test finishes.
      event.set()
98
99
class TestRunAsync(unittest.TestCase):
  """Exercises reraiser_thread.RunAsync with zero, one and two callables."""

  def testNoArgs(self):
    self.assertEqual([], reraiser_thread.RunAsync([]))

  def testOneArg(self):
    self.assertEqual([1], reraiser_thread.RunAsync([lambda: 1]))

  def testTwoArgs(self):
    # Results come back in the same order as the callables were given.
    first, second = reraiser_thread.RunAsync((lambda: 1, lambda: 2))
    self.assertEqual(1, first)
    self.assertEqual(2, second)
115
if '__main__' == __name__:
  # Run every TestCase defined in this module when executed directly.
  unittest.main()
118