1import sys
2
3from cStringIO import StringIO
4
5import unittest
6
7
def resultFactory(*ignored):
    """Build and return a brand-new, plain ``unittest.TestResult``.

    Any positional arguments are accepted and discarded, which lets this
    function be handed to ``TextTestRunner`` as its ``resultclass``.
    """
    fresh = unittest.TestResult()
    return fresh
10
11
class TestSetups(unittest.TestCase):
    """Exercise class- and module-level fixtures as run by unittest.TestSuite.

    Covers setUpClass/tearDownClass and setUpModule/tearDownModule: call
    counts, ordering, error and SkipTest reporting, skipped classes, and
    TestSuite.debug().  Module fixtures are simulated by pointing a
    TestCase's ``__module__`` at a made-up name and registering a stand-in
    object under that name in ``sys.modules``.  Every such registration is
    undone by a cleanup, so the stand-ins do not leak into later tests.
    """

    def getRunner(self):
        """Return a quiet TextTestRunner.

        Output goes to an in-memory StringIO, and results are produced by
        the module-level ``resultFactory``.
        """
        return unittest.TextTestRunner(resultclass=resultFactory,
                                       stream=StringIO())

    def _preserveModule(self, name):
        """Arrange for ``sys.modules[name]`` to be restored after this test.

        If *name* is currently absent, the cleanup removes whatever entry
        the test installed. Otherwise it puts the original object back.
        """
        missing = object()
        saved = sys.modules.get(name, missing)

        def restore():
            if saved is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = saved

        self.addCleanup(restore)

    def _installModule(self, name, module):
        """Register *module* as ``sys.modules[name]`` for this test only."""
        self._preserveModule(name)
        sys.modules[name] = module

    def runTests(self, *cases):
        """Load every TestCase class in *cases*, run them all, return the result.

        The loaded tests are wrapped in an extra outer suite, and empty
        suites are appended to both the inner and the outer suite. Those
        shapes are deliberate: they are meant to expose fixture-handling
        bugs in TestSuite.
        """
        suite = unittest.TestSuite()
        for case in cases:
            tests = unittest.defaultTestLoader.loadTestsFromTestCase(case)
            suite.addTests(tests)

        runner = self.getRunner()

        # creating a nested suite exposes some potential bugs
        realSuite = unittest.TestSuite()
        realSuite.addTest(suite)
        # adding empty suites to the end exposes potential bugs
        suite.addTest(unittest.TestSuite())
        realSuite.addTest(unittest.TestSuite())
        return runner.run(realSuite)

    def test_setup_class(self):
        # setUpClass must run exactly once for all tests of the class.
        class Test(unittest.TestCase):
            setUpCalled = 0
            @classmethod
            def setUpClass(cls):
                Test.setUpCalled += 1
                # Chain to the base implementation the idiomatic way.
                super(Test, cls).setUpClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.setUpCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class(self):
        # tearDownClass must run exactly once for all tests of the class.
        class Test(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                super(Test, cls).tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class_two_classes(self):
        # Each class in the same run gets its own single tearDownClass call.
        class Test(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                super(Test, cls).tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tearDownCalled += 1
                super(Test2, cls).tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(Test2.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setupclass(self):
        # A failing setUpClass is reported as one error and no tests run.
        class BrokenTest(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(BrokenTest)

        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'setUpClass (%s.BrokenTest)' % __name__)

    def test_error_in_teardown_class(self):
        # A failing tearDownClass is reported once per class, in run order,
        # and does not stop the following class from running.
        class Test(unittest.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 2)
        self.assertEqual(Test.tornDown, 1)
        self.assertEqual(Test2.tornDown, 1)

        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'tearDownClass (%s.Test)' % __name__)
        # The second error must be attributed to Test2, not Test again.
        error, _ = result.errors[1]
        self.assertEqual(str(error),
                    'tearDownClass (%s.Test2)' % __name__)

    def test_class_not_torndown_when_setup_fails(self):
        # tearDownClass must be skipped when setUpClass raised.
        class Test(unittest.TestCase):
            tornDown = False
            @classmethod
            def setUpClass(cls):
                raise TypeError
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
                raise TypeError('foo')
            def test_one(self):
                pass

        result = self.runTests(Test)
        self.assertFalse(Test.tornDown)
        # Only the setUpClass failure is reported; tearDownClass never ran.
        self.assertEqual(len(result.errors), 1)

    def test_class_not_setup_or_torndown_when_skipped(self):
        # A class decorated with unittest.skip gets neither class fixture.
        class Test(unittest.TestCase):
            classSetUp = False
            tornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
            def test_one(self):
                pass

        Test = unittest.skip("hop")(Test)
        self.runTests(Test)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.tornDown)

    def test_setup_teardown_order_with_pathological_suite(self):
        # Every test sits in its own one-element suite. Class and module
        # fixtures must still run once each, in the order asserted below.
        results = []

        class Module1(object):
            @staticmethod
            def setUpModule():
                results.append('Module1.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module1.tearDownModule')

        class Module2(object):
            @staticmethod
            def setUpModule():
                results.append('Module2.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module2.tearDownModule')

        class Test1(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 1')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 1')
            def testOne(self):
                results.append('Test1.testOne')
            def testTwo(self):
                results.append('Test1.testTwo')

        class Test2(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 2')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 2')
            def testOne(self):
                results.append('Test2.testOne')
            def testTwo(self):
                results.append('Test2.testTwo')

        class Test3(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 3')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 3')
            def testOne(self):
                results.append('Test3.testOne')
            def testTwo(self):
                results.append('Test3.testTwo')

        Test1.__module__ = Test2.__module__ = 'Module'
        Test3.__module__ = 'Module2'
        self._installModule('Module', Module1)
        self._installModule('Module2', Module2)

        first = unittest.TestSuite((Test1('testOne'),))
        second = unittest.TestSuite((Test1('testTwo'),))
        third = unittest.TestSuite((Test2('testOne'),))
        fourth = unittest.TestSuite((Test2('testTwo'),))
        fifth = unittest.TestSuite((Test3('testOne'),))
        sixth = unittest.TestSuite((Test3('testTwo'),))
        suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth))

        runner = self.getRunner()
        result = runner.run(suite)
        self.assertEqual(result.testsRun, 6)
        self.assertEqual(len(result.errors), 0)

        self.assertEqual(results,
                         ['Module1.setUpModule', 'setup 1',
                          'Test1.testOne', 'Test1.testTwo', 'teardown 1',
                          'setup 2', 'Test2.testOne', 'Test2.testTwo',
                          'teardown 2', 'Module1.tearDownModule',
                          'Module2.setUpModule', 'setup 3',
                          'Test3.testOne', 'Test3.testTwo',
                          'teardown 3', 'Module2.tearDownModule'])

    def test_setup_module(self):
        # setUpModule runs once for all tests belonging to the module.
        class Module(object):
            moduleSetup = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1

        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        self._installModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setup_module(self):
        # A failing setUpModule prevents every class fixture and test of the
        # module from running, and tearDownModule is not called.
        class Module(object):
            moduleSetup = 0
            moduleTornDown = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1
                raise TypeError('foo')
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        self._installModule('Module', Module)

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(Module.moduleTornDown, 0)
        self.assertEqual(result.testsRun, 0)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'setUpModule (Module)')

    def test_testcase_with_missing_module(self):
        # A TestCase whose __module__ is not in sys.modules still runs.
        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        # Remove any 'Module' entry for this test only; restored afterwards.
        self._preserveModule('Module')
        sys.modules.pop('Module', None)

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 2)

    def test_teardown_module(self):
        # tearDownModule runs once after all tests of the module.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        self._installModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_teardown_module(self):
        # A failing tearDownModule is reported once, after all tests and
        # class fixtures of the module have run.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1
                raise TypeError('foo')

        class Test(unittest.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        self._installModule('Module', Module)

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertTrue(Test.classSetUp)
        self.assertTrue(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'tearDownModule (Module)')

    def test_skiptest_in_setupclass(self):
        # SkipTest raised from setUpClass is recorded as a skip, not an error.
        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                raise unittest.SkipTest('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped), 'setUpClass (%s.Test)' % __name__)

    def test_skiptest_in_setupmodule(self):
        # SkipTest raised from setUpModule is recorded as a skip, not an error.
        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Module(object):
            @staticmethod
            def setUpModule():
                raise unittest.SkipTest('foo')

        Test.__module__ = 'Module'
        self._installModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped), 'setUpModule (Module)')

    def test_suite_debug_executes_setups_and_teardowns(self):
        # TestSuite.debug() must run module and class fixtures around the test.
        ordering = []

        class Module(object):
            @staticmethod
            def setUpModule():
                ordering.append('setUpModule')
            @staticmethod
            def tearDownModule():
                ordering.append('tearDownModule')

        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                ordering.append('setUpClass')
            @classmethod
            def tearDownClass(cls):
                ordering.append('tearDownClass')
            def test_something(self):
                ordering.append('test_something')

        Test.__module__ = 'Module'
        self._installModule('Module', Module)

        suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
        suite.debug()
        expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule']
        self.assertEqual(ordering, expectedOrder)

    def test_suite_debug_propagates_exceptions(self):
        # In debug mode, an exception from any fixture or test must propagate.
        # The fixtures read `phase` from this function's scope when they are
        # called, so each loop iteration below makes exactly one of them raise.
        class Module(object):
            @staticmethod
            def setUpModule():
                if phase == 0:
                    raise Exception('setUpModule')
            @staticmethod
            def tearDownModule():
                if phase == 1:
                    raise Exception('tearDownModule')

        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                if phase == 2:
                    raise Exception('setUpClass')
            @classmethod
            def tearDownClass(cls):
                if phase == 3:
                    raise Exception('tearDownClass')
            def test_something(self):
                if phase == 4:
                    raise Exception('test_something')

        Test.__module__ = 'Module'
        self._installModule('Module', Module)

        _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
        suite = unittest.TestSuite()
        suite.addTest(_suite)

        messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something')
        for phase, msg in enumerate(messages):
            with self.assertRaisesRegexp(Exception, msg):
                suite.debug()
506
if __name__ == "__main__":
    # Allow this test module to be executed directly as a script.
    unittest.main()
509