test_dispatcher_unittest.py revision 5d1f7b1de12d16ceb2c938c56701a3e8bfa558f7
1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Unittests for test_dispatcher.py."""
6# pylint: disable=R0201
7# pylint: disable=W0212
8
9import os
10import sys
11import unittest
12
13sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),
14                os.pardir, os.pardir))
15
16# Mock out android_commands.GetAttachedDevices().
17from pylib import android_commands
18android_commands.GetAttachedDevices = lambda: ['0', '1']
19from pylib import constants
20from pylib.base import base_test_result
21from pylib.base import test_dispatcher
22from pylib.utils import watchdog_timer
23
24
25
class TestException(Exception):
  """Error raised by MockRunnerException so testReraise can catch it."""
28
29
class MockRunner(object):
  """A mock TestRunner that passes every test it is given.

  It counts SetUp and TearDown invocations so tests can check how the
  dispatcher drives a runner's lifecycle.
  """

  def __init__(self, device='0', shard_index=0):
    self.device, self.shard_index = device, shard_index
    self.setups = self.teardowns = 0

  def RunTest(self, test):
    """Returns (results, retry): a single PASS for |test| and no retry."""
    run_results = base_test_result.TestRunResults()
    passed = base_test_result.BaseTestResult(
        test, base_test_result.ResultType.PASS)
    run_results.AddResult(passed)
    return run_results, None

  def SetUp(self):
    self.setups = self.setups + 1

  def TearDown(self):
    self.teardowns = self.teardowns + 1
49
50
class MockRunnerFail(MockRunner):
  """A mock TestRunner that fails every test and hands it back to retry."""

  def RunTest(self, test):
    outcome = base_test_result.TestRunResults()
    failure = base_test_result.BaseTestResult(
        test, base_test_result.ResultType.FAIL)
    outcome.AddResult(failure)
    # The second element names the test to run again.
    return outcome, test
57
58
class MockRunnerFailTwice(MockRunner):
  """A mock TestRunner whose first two RunTest calls fail; later ones pass.

  The call counter lives on the runner instance, so it is shared by every
  test this runner executes rather than tracked per test.
  """

  def __init__(self, device='0', shard_index=0):
    super(MockRunnerFailTwice, self).__init__(device, shard_index)
    self._fails = 0

  def RunTest(self, test):
    self._fails += 1
    should_fail = self._fails <= 2
    outcome = base_test_result.TestRunResults()
    if should_fail:
      result_type, retry = base_test_result.ResultType.FAIL, test
    else:
      result_type, retry = base_test_result.ResultType.PASS, None
    outcome.AddResult(base_test_result.BaseTestResult(test, result_type))
    return outcome, retry
75
76
class MockRunnerException(MockRunner):
  """A mock TestRunner whose RunTest always raises TestException."""

  def RunTest(self, test):
    raise TestException()
80
81
class TestFunctions(unittest.TestCase):
  """Tests test_dispatcher's standalone helpers.

  Covers _RunTestsFromQueue, _SetUp and _ThreadSafeCounter.
  """

  @staticmethod
  def _RunTests(mock_runner, tests):
    """Runs |tests| on |mock_runner| through _RunTestsFromQueue.

    Args:
      mock_runner: The runner instance that executes each test.
      tests: A list of test names.

    Returns:
      A TestRunResults merging every result collected from the queue.
    """
    results = []
    tests = test_dispatcher._TestCollection(
        [test_dispatcher._Test(t) for t in tests])
    # NOTE(review): the trailing 2 appears to be the retry budget
    # (num_retries); confirm against _RunTestsFromQueue's signature.
    test_dispatcher._RunTestsFromQueue(mock_runner, tests, results,
                                       watchdog_timer.WatchdogTimer(None), 2)
    run_results = base_test_result.TestRunResults()
    for r in results:
      run_results.AddTestRunResults(r)
    return run_results

  def testRunTestsFromQueue(self):
    results = TestFunctions._RunTests(MockRunner(), ['a', 'b'])
    self.assertEqual(len(results.GetPass()), 2)
    self.assertEqual(len(results.GetNotPass()), 0)

  def testRunTestsFromQueueRetry(self):
    # Tests that always fail must each end with one recorded failure.
    results = TestFunctions._RunTests(MockRunnerFail(), ['a', 'b'])
    self.assertEqual(len(results.GetPass()), 0)
    self.assertEqual(len(results.GetFail()), 2)

  def testRunTestsFromQueueFailTwice(self):
    # The runner's first two calls fail; retrying lets both tests pass.
    results = TestFunctions._RunTests(MockRunnerFailTwice(), ['a', 'b'])
    self.assertEqual(len(results.GetPass()), 2)
    self.assertEqual(len(results.GetNotPass()), 0)

  def testSetUp(self):
    runners = []
    counter = test_dispatcher._ThreadSafeCounter()
    test_dispatcher._SetUp(MockRunner, '0', runners, counter)
    self.assertEqual(len(runners), 1)
    self.assertEqual(runners[0].setups, 1)

  def testThreadSafeCounter(self):
    counter = test_dispatcher._ThreadSafeCounter()
    # range (not xrange) keeps this runnable under both Python 2 and 3.
    for i in range(5):
      self.assertEqual(counter.GetAndIncrement(), i)
121      self.assertEqual(counter.GetAndIncrement(), i)
122
123
class TestThreadGroupFunctions(unittest.TestCase):
  """Tests _CreateRunners, _RunAllTests and _TearDownRunners."""

  def setUp(self):
    self.tests = list('abcdefg')
    collection = test_dispatcher._TestCollection(
        [test_dispatcher._Test(name) for name in self.tests])
    # The factory hands back the same collection on every call, so
    # presumably all runners draw from one shared pool of tests.
    self.test_collection_factory = lambda: collection

  def testCreate(self):
    runners = test_dispatcher._CreateRunners(MockRunner, ['0', '1'])
    for r in runners:
      self.assertEqual(r.setups, 1)
    self.assertEqual({r.device for r in runners}, {'0', '1'})
    self.assertEqual({r.shard_index for r in runners}, {0, 1})

  def testRun(self):
    runners = [MockRunner(device) for device in ('0', '1')]
    results, exit_code = test_dispatcher._RunAllTests(
        runners, self.test_collection_factory, 0)
    self.assertEqual(len(results.GetPass()), len(self.tests))
    self.assertEqual(exit_code, 0)

  def testTearDown(self):
    runners = [MockRunner(device) for device in ('0', '1')]
    test_dispatcher._TearDownRunners(runners)
    for r in runners:
      self.assertEqual(r.teardowns, 1)

  def testRetry(self):
    runners = test_dispatcher._CreateRunners(MockRunnerFail, ['0', '1'])
    results, exit_code = test_dispatcher._RunAllTests(
        runners, self.test_collection_factory, 0)
    self.assertEqual(len(results.GetFail()), len(self.tests))
    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)

  def testReraise(self):
    runners = test_dispatcher._CreateRunners(MockRunnerException, ['0', '1'])
    with self.assertRaises(TestException):
      test_dispatcher._RunAllTests(runners, self.test_collection_factory, 0)
164      test_dispatcher._RunAllTests(runners, self.test_collection_factory, 0)
165
166
class TestShard(unittest.TestCase):
  """Tests test_dispatcher.RunTests in sharding mode (shard=True)."""

  @staticmethod
  def _RunShard(runner_factory):
    """Shards tests 'a', 'b' and 'c' across mock devices '0' and '1'."""
    return test_dispatcher.RunTests(
        ['a', 'b', 'c'], runner_factory, ['0', '1'], shard=True)

  def testShard(self):
    results, exit_code = self._RunShard(MockRunner)
    self.assertEqual(len(results.GetPass()), 3)
    self.assertEqual(exit_code, 0)

  def testFailing(self):
    results, exit_code = self._RunShard(MockRunnerFail)
    self.assertEqual(len(results.GetPass()), 0)
    self.assertEqual(len(results.GetFail()), 3)
    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)

  def testNoTests(self):
    results, exit_code = test_dispatcher.RunTests(
        [], MockRunner, ['0', '1'], shard=True)
    self.assertEqual(len(results.GetAll()), 0)
    # An empty test list is reported with the error exit code.
    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)

  def testTestsRemainWithAllDevicesOffline(self):
    # Temporarily report no attached devices; always restore the module-level
    # stub so later tests see devices '0' and '1' again.
    saved_get_devices = android_commands.GetAttachedDevices
    android_commands.GetAttachedDevices = lambda: []
    try:
      with self.assertRaises(AssertionError):
        _results, _exit_code = self._RunShard(MockRunner)
    finally:
      android_commands.GetAttachedDevices = saved_get_devices
198      android_commands.GetAttachedDevices = attached_devices
199
200
class TestReplicate(unittest.TestCase):
  """Tests test_dispatcher.RunTests in replication mode (shard=False)."""

  @staticmethod
  def _RunReplicate(runner_factory):
    """Replicates tests 'a', 'b' and 'c' onto each of devices '0' and '1'."""
    return test_dispatcher.RunTests(
        ['a', 'b', 'c'], runner_factory, ['0', '1'], shard=False)

  def testReplicate(self):
    results, exit_code = self._RunReplicate(MockRunner)
    # Every test runs on every device: 3 tests x 2 devices = 6 results.
    self.assertEqual(len(results.GetPass()), 6)
    self.assertEqual(exit_code, 0)

  def testFailing(self):
    results, exit_code = self._RunReplicate(MockRunnerFail)
    self.assertEqual(len(results.GetPass()), 0)
    self.assertEqual(len(results.GetFail()), 6)
    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)

  def testNoTests(self):
    results, exit_code = test_dispatcher.RunTests(
        [], MockRunner, ['0', '1'], shard=False)
    self.assertEqual(len(results.GetAll()), 0)
    # An empty test list is reported with the error exit code.
    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)
224    self.assertEqual(exit_code, constants.ERROR_EXIT_CODE)
225
226
227if __name__ == '__main__':
228  unittest.main()
229