1import io
2import sys
3
4import unittest
5
6
def resultFactory(*_):
    """Return a fresh, bare ``unittest.TestResult``.

    Used as a runner's ``resultclass``: every positional argument the
    runner supplies is accepted and discarded, so results are collected
    in a plain ``TestResult`` rather than one that writes to a stream.
    """
    fresh_result = unittest.TestResult()
    return fresh_result
9
10
class TestSetups(unittest.TestCase):
    """Tests for class- and module-level fixtures run by unittest suites.

    Covers setUpClass/tearDownClass and setUpModule/tearDownModule: how
    often they are called, how exceptions and SkipTest raised from them
    are reported, their ordering across classes and modules, and the
    behaviour of TestSuite.debug().
    """

    def setUp(self):
        # Many tests below register fake modules in sys.modules under these
        # names (and one deliberately removes 'Module').  Register cleanups
        # that restore each name to exactly its prior state so no fake
        # module leaks into later tests.
        for name in ('Module', 'Module2'):
            if name in sys.modules:
                self.addCleanup(sys.modules.__setitem__, name,
                                sys.modules[name])
            else:
                self.addCleanup(sys.modules.pop, name, None)

    def getRunner(self):
        """Return a quiet TextTestRunner that records into a plain TestResult."""
        return unittest.TextTestRunner(resultclass=resultFactory,
                                       stream=io.StringIO())

    def runTests(self, *cases):
        """Load all tests from the TestCase classes in *cases* and run them.

        The tests are wrapped in a deliberately awkward suite structure
        (nested suites plus trailing empty suites) and the TestResult
        produced by the runner is returned.
        """
        suite = unittest.TestSuite()
        for case in cases:
            tests = unittest.defaultTestLoader.loadTestsFromTestCase(case)
            suite.addTests(tests)

        runner = self.getRunner()

        # creating a nested suite exposes some potential bugs
        realSuite = unittest.TestSuite()
        realSuite.addTest(suite)
        # adding empty suites to the end exposes potential bugs
        suite.addTest(unittest.TestSuite())
        realSuite.addTest(unittest.TestSuite())
        return runner.run(realSuite)

    def test_setup_class(self):
        # setUpClass runs exactly once for a class with several tests.
        class Test(unittest.TestCase):
            setUpCalled = 0
            @classmethod
            def setUpClass(cls):
                Test.setUpCalled += 1
                unittest.TestCase.setUpClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.setUpCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class(self):
        # tearDownClass runs exactly once for a class with several tests.
        class Test(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                unittest.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class_two_classes(self):
        # Each class in a multi-class suite gets its own single tearDownClass.
        class Test(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                unittest.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tearDownCalled += 1
                unittest.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(Test2.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setupclass(self):
        # An exception in setUpClass skips the class's tests and is
        # reported as a single error named after the class fixture.
        class BrokenTest(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(BrokenTest)

        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'setUpClass (%s.%s)' % (__name__, BrokenTest.__qualname__))

    def test_error_in_teardown_class(self):
        # An exception in tearDownClass is reported once per class, in the
        # order the classes were torn down, without affecting the tests.
        class Test(unittest.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 2)
        self.assertEqual(Test.tornDown, 1)
        self.assertEqual(Test2.tornDown, 1)

        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'tearDownClass (%s.%s)' % (__name__, Test.__qualname__))
        # The second error must be attributed to the second class.
        error, _ = result.errors[1]
        self.assertEqual(str(error),
                    'tearDownClass (%s.%s)' % (__name__, Test2.__qualname__))

    def test_class_not_torndown_when_setup_fails(self):
        # tearDownClass must not run if setUpClass raised.
        class Test(unittest.TestCase):
            tornDown = False
            @classmethod
            def setUpClass(cls):
                raise TypeError
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
                raise TypeError('foo')
            def test_one(self):
                pass

        self.runTests(Test)
        self.assertFalse(Test.tornDown)

    def test_class_not_setup_or_torndown_when_skipped(self):
        # A class decorated with unittest.skip gets neither class fixture.
        class Test(unittest.TestCase):
            classSetUp = False
            tornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
            def test_one(self):
                pass

        Test = unittest.skip("hop")(Test)
        self.runTests(Test)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.tornDown)

    def test_setup_teardown_order_with_pathological_suite(self):
        # Every test sits in its own one-element suite; class and module
        # fixtures must still run once each, in the correct order.
        results = []

        class Module1(object):
            @staticmethod
            def setUpModule():
                results.append('Module1.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module1.tearDownModule')

        class Module2(object):
            @staticmethod
            def setUpModule():
                results.append('Module2.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module2.tearDownModule')

        class Test1(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 1')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 1')
            def testOne(self):
                results.append('Test1.testOne')
            def testTwo(self):
                results.append('Test1.testTwo')

        class Test2(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 2')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 2')
            def testOne(self):
                results.append('Test2.testOne')
            def testTwo(self):
                results.append('Test2.testTwo')

        class Test3(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 3')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 3')
            def testOne(self):
                results.append('Test3.testOne')
            def testTwo(self):
                results.append('Test3.testTwo')

        # Pretend Test1/Test2 live in 'Module' and Test3 in 'Module2'; the
        # fake modules are removed again by the cleanups set up in setUp().
        Test1.__module__ = Test2.__module__ = 'Module'
        Test3.__module__ = 'Module2'
        sys.modules['Module'] = Module1
        sys.modules['Module2'] = Module2

        first = unittest.TestSuite((Test1('testOne'),))
        second = unittest.TestSuite((Test1('testTwo'),))
        third = unittest.TestSuite((Test2('testOne'),))
        fourth = unittest.TestSuite((Test2('testTwo'),))
        fifth = unittest.TestSuite((Test3('testOne'),))
        sixth = unittest.TestSuite((Test3('testTwo'),))
        suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth))

        runner = self.getRunner()
        result = runner.run(suite)
        self.assertEqual(result.testsRun, 6)
        self.assertEqual(len(result.errors), 0)

        self.assertEqual(results,
                         ['Module1.setUpModule', 'setup 1',
                          'Test1.testOne', 'Test1.testTwo', 'teardown 1',
                          'setup 2', 'Test2.testOne', 'Test2.testTwo',
                          'teardown 2', 'Module1.tearDownModule',
                          'Module2.setUpModule', 'setup 3',
                          'Test3.testOne', 'Test3.testTwo',
                          'teardown 3', 'Module2.tearDownModule'])

    def test_setup_module(self):
        # setUpModule runs exactly once for a module's tests.
        class Module(object):
            moduleSetup = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1

        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        sys.modules['Module'] = Module

        result = self.runTests(Test)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setup_module(self):
        # A failing setUpModule prevents every test and class fixture in the
        # module from running, skips tearDownModule, and yields one error.
        class Module(object):
            moduleSetup = 0
            moduleTornDown = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1
                raise TypeError('foo')
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        sys.modules['Module'] = Module

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(Module.moduleTornDown, 0)
        self.assertEqual(result.testsRun, 0)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'setUpModule (Module)')

    def test_testcase_with_missing_module(self):
        # A test whose __module__ is absent from sys.modules still runs.
        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        sys.modules.pop('Module', None)

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 2)

    def test_teardown_module(self):
        # tearDownModule runs exactly once for a module's tests.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        sys.modules['Module'] = Module

        result = self.runTests(Test)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_teardown_module(self):
        # A failing tearDownModule is reported once; the tests and class
        # fixtures of that module still ran normally.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1
                raise TypeError('foo')

        class Test(unittest.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        sys.modules['Module'] = Module

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertTrue(Test.classSetUp)
        self.assertTrue(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'tearDownModule (Module)')

    def test_skiptest_in_setupclass(self):
        # SkipTest from setUpClass is recorded as one skip, not an error.
        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                raise unittest.SkipTest('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped),
                    'setUpClass (%s.%s)' % (__name__, Test.__qualname__))

    def test_skiptest_in_setupmodule(self):
        # SkipTest from setUpModule is recorded as one skip, not an error.
        class Test(unittest.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Module(object):
            @staticmethod
            def setUpModule():
                raise unittest.SkipTest('foo')

        Test.__module__ = 'Module'
        sys.modules['Module'] = Module

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped), 'setUpModule (Module)')

    def test_suite_debug_executes_setups_and_teardowns(self):
        # TestSuite.debug() runs module and class fixtures around the tests.
        ordering = []

        class Module(object):
            @staticmethod
            def setUpModule():
                ordering.append('setUpModule')
            @staticmethod
            def tearDownModule():
                ordering.append('tearDownModule')

        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                ordering.append('setUpClass')
            @classmethod
            def tearDownClass(cls):
                ordering.append('tearDownClass')
            def test_something(self):
                ordering.append('test_something')

        Test.__module__ = 'Module'
        sys.modules['Module'] = Module

        suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
        suite.debug()
        expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule']
        self.assertEqual(ordering, expectedOrder)

    def test_suite_debug_propagates_exceptions(self):
        # TestSuite.debug() lets an exception from any fixture or test
        # escape.  The fixtures read `phase` from this function's scope at
        # call time (late binding), so each loop iteration below makes a
        # different stage raise.
        class Module(object):
            @staticmethod
            def setUpModule():
                if phase == 0:
                    raise Exception('setUpModule')
            @staticmethod
            def tearDownModule():
                if phase == 1:
                    raise Exception('tearDownModule')

        class Test(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                if phase == 2:
                    raise Exception('setUpClass')
            @classmethod
            def tearDownClass(cls):
                if phase == 3:
                    raise Exception('tearDownClass')
            def test_something(self):
                if phase == 4:
                    raise Exception('test_something')

        Test.__module__ = 'Module'
        sys.modules['Module'] = Module

        messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something')
        for phase, msg in enumerate(messages):
            _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
            suite = unittest.TestSuite([_suite])
            with self.assertRaisesRegex(Exception, msg):
                suite.debug()
504
505
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()
508