1import sys
2
3from cStringIO import StringIO
4
5import unittest2
6from unittest2.test.support import resultFactory
7
8
class TestSetups(unittest2.TestCase):
    """Tests for class- and module-level fixtures (setUpClass,
    tearDownClass, setUpModule, tearDownModule) as run by unittest2 suites."""

    def getRunner(self):
        """Return a TextTestRunner writing to an in-memory stream and
        building its result object with ``resultFactory``."""
        return unittest2.TextTestRunner(resultclass=resultFactory,
                                          stream=StringIO())

    def _setModule(self, name, module):
        """Make ``sys.modules[name]`` refer to *module* for this test only.

        Passing ``None`` removes the entry instead. The previous entry (or
        its absence) is restored by a cleanup, so the fake modules these
        tests register do not leak into other tests.
        """
        missing = object()
        previous = sys.modules.get(name, missing)

        def restore():
            if previous is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
        self.addCleanup(restore)

        if module is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = module

    def runTests(self, *cases):
        """Load all tests from each TestCase class in *cases* and run them.

        The loaded tests are wrapped in an outer suite, with empty suites
        appended, to exercise suite-nesting edge cases. Returns the result
        of ``runner.run``.
        """
        suite = unittest2.TestSuite()
        for case in cases:
            tests = unittest2.defaultTestLoader.loadTestsFromTestCase(case)
            suite.addTests(tests)

        runner = self.getRunner()

        # creating a nested suite exposes some potential bugs
        realSuite = unittest2.TestSuite()
        realSuite.addTest(suite)
        # adding empty suites to the end exposes potential bugs
        suite.addTest(unittest2.TestSuite())
        realSuite.addTest(unittest2.TestSuite())
        return runner.run(realSuite)

    def test_setup_class(self):
        # setUpClass must run exactly once for a class with two tests.
        class Test(unittest2.TestCase):
            setUpCalled = 0
            @classmethod
            def setUpClass(cls):
                Test.setUpCalled += 1
                unittest2.TestCase.setUpClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.setUpCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class(self):
        # tearDownClass must run exactly once for a class with two tests.
        class Test(unittest2.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                unittest2.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_teardown_class_two_classes(self):
        # Each of two consecutive classes gets its own single tearDownClass.
        class Test(unittest2.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test.tearDownCalled += 1
                unittest2.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest2.TestCase):
            tearDownCalled = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tearDownCalled += 1
                unittest2.TestCase.tearDownClass()
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)

        self.assertEqual(Test.tearDownCalled, 1)
        self.assertEqual(Test2.tearDownCalled, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setupclass(self):
        # A failing setUpClass is reported once and none of its tests run.
        class BrokenTest(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(BrokenTest)

        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'setUpClass (%s.BrokenTest)' % __name__)

    def test_error_in_teardown_class(self):
        # A failing tearDownClass is reported per class and does not stop
        # the following class from running.
        class Test(unittest2.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest2.TestCase):
            tornDown = 0
            @classmethod
            def tearDownClass(cls):
                Test2.tornDown += 1
                raise TypeError('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test, Test2)
        self.assertEqual(result.testsRun, 4)
        self.assertEqual(len(result.errors), 2)
        self.assertEqual(Test.tornDown, 1)
        self.assertEqual(Test2.tornDown, 1)

        error, _ = result.errors[0]
        self.assertEqual(str(error),
                    'tearDownClass (%s.Test)' % __name__)
        error, _ = result.errors[1]
        self.assertEqual(str(error),
                    'tearDownClass (%s.Test2)' % __name__)

    def test_class_not_torndown_when_setup_fails(self):
        # tearDownClass is skipped for a class whose setUpClass raised.
        class Test(unittest2.TestCase):
            tornDown = False
            @classmethod
            def setUpClass(cls):
                raise TypeError
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
                raise TypeError('foo')
            def test_one(self):
                pass

        self.runTests(Test)
        self.assertFalse(Test.tornDown)

    def test_class_not_setup_or_torndown_when_skipped(self):
        # A class-level skip suppresses both class fixtures.
        class Test(unittest2.TestCase):
            classSetUp = False
            tornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.tornDown = True
            def test_one(self):
                pass

        Test = unittest2.skip("hop")(Test)
        self.runTests(Test)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.tornDown)

    def test_setup_teardown_order_with_pathological_suite(self):
        # Every test sits in its own single-test suite, so class and module
        # boundaries must be tracked across sibling suites.
        results = []

        class Module1(object):
            @staticmethod
            def setUpModule():
                results.append('Module1.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module1.tearDownModule')

        class Module2(object):
            @staticmethod
            def setUpModule():
                results.append('Module2.setUpModule')
            @staticmethod
            def tearDownModule():
                results.append('Module2.tearDownModule')

        class Test1(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 1')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 1')
            def testOne(self):
                results.append('Test1.testOne')
            def testTwo(self):
                results.append('Test1.testTwo')

        class Test2(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 2')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 2')
            def testOne(self):
                results.append('Test2.testOne')
            def testTwo(self):
                results.append('Test2.testTwo')

        class Test3(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                results.append('setup 3')
            @classmethod
            def tearDownClass(cls):
                results.append('teardown 3')
            def testOne(self):
                results.append('Test3.testOne')
            def testTwo(self):
                results.append('Test3.testTwo')

        Test1.__module__ = Test2.__module__ = 'Module'
        Test3.__module__ = 'Module2'
        self._setModule('Module', Module1)
        self._setModule('Module2', Module2)

        first = unittest2.TestSuite((Test1('testOne'),))
        second = unittest2.TestSuite((Test1('testTwo'),))
        third = unittest2.TestSuite((Test2('testOne'),))
        fourth = unittest2.TestSuite((Test2('testTwo'),))
        fifth = unittest2.TestSuite((Test3('testOne'),))
        sixth = unittest2.TestSuite((Test3('testTwo'),))
        suite = unittest2.TestSuite((first, second, third, fourth, fifth, sixth))

        runner = self.getRunner()
        result = runner.run(suite)
        self.assertEqual(result.testsRun, 6)
        self.assertEqual(len(result.errors), 0)

        self.assertEqual(results,
                         ['Module1.setUpModule', 'setup 1',
                          'Test1.testOne', 'Test1.testTwo', 'teardown 1',
                          'setup 2', 'Test2.testOne', 'Test2.testTwo',
                          'teardown 2', 'Module1.tearDownModule',
                          'Module2.setUpModule', 'setup 3',
                          'Test3.testOne', 'Test3.testTwo',
                          'teardown 3', 'Module2.tearDownModule'])

    def test_setup_module(self):
        # setUpModule runs once for all tests of a module.
        class Module(object):
            moduleSetup = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1

        class Test(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        self._setModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_setup_module(self):
        # A failing setUpModule prevents every class/test in the module
        # from running, and tearDownModule is not called.
        class Module(object):
            moduleSetup = 0
            moduleTornDown = 0
            @staticmethod
            def setUpModule():
                Module.moduleSetup += 1
                raise TypeError('foo')
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest2.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        self._setModule('Module', Module)

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleSetup, 1)
        self.assertEqual(Module.moduleTornDown, 0)
        self.assertEqual(result.testsRun, 0)
        self.assertFalse(Test.classSetUp)
        self.assertFalse(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'setUpModule (Module)')

    def test_testcase_with_missing_module(self):
        # Tests still run when their __module__ is absent from sys.modules.
        class Test(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        self._setModule('Module', None)

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 2)

    def test_teardown_module(self):
        # tearDownModule runs once after all tests of a module.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1

        class Test(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        self._setModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.errors), 0)

    def test_error_in_teardown_module(self):
        # A failing tearDownModule is reported once, after all tests and
        # class fixtures have run normally.
        class Module(object):
            moduleTornDown = 0
            @staticmethod
            def tearDownModule():
                Module.moduleTornDown += 1
                raise TypeError('foo')

        class Test(unittest2.TestCase):
            classSetUp = False
            classTornDown = False
            @classmethod
            def setUpClass(cls):
                Test.classSetUp = True
            @classmethod
            def tearDownClass(cls):
                Test.classTornDown = True
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Test2(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass
        Test.__module__ = 'Module'
        Test2.__module__ = 'Module'
        self._setModule('Module', Module)

        result = self.runTests(Test, Test2)
        self.assertEqual(Module.moduleTornDown, 1)
        self.assertEqual(result.testsRun, 4)
        self.assertTrue(Test.classSetUp)
        self.assertTrue(Test.classTornDown)
        self.assertEqual(len(result.errors), 1)
        error, _ = result.errors[0]
        self.assertEqual(str(error), 'tearDownModule (Module)')

    def test_skiptest_in_setupclass(self):
        # SkipTest raised from setUpClass is recorded as a skip, not an error.
        class Test(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                raise unittest2.SkipTest('foo')
            def test_one(self):
                pass
            def test_two(self):
                pass

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped), 'setUpClass (%s.Test)' % __name__)

    def test_skiptest_in_setupmodule(self):
        # SkipTest raised from setUpModule is recorded as a skip, not an error.
        class Test(unittest2.TestCase):
            def test_one(self):
                pass
            def test_two(self):
                pass

        class Module(object):
            @staticmethod
            def setUpModule():
                raise unittest2.SkipTest('foo')

        Test.__module__ = 'Module'
        self._setModule('Module', Module)

        result = self.runTests(Test)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)
        skipped = result.skipped[0][0]
        self.assertEqual(str(skipped), 'setUpModule (Module)')

    def test_suite_debug_executes_setups_and_teardowns(self):
        # TestSuite.debug() must drive module and class fixtures in order.
        ordering = []

        class Module(object):
            @staticmethod
            def setUpModule():
                ordering.append('setUpModule')
            @staticmethod
            def tearDownModule():
                ordering.append('tearDownModule')

        class Test(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                ordering.append('setUpClass')
            @classmethod
            def tearDownClass(cls):
                ordering.append('tearDownClass')
            def test_something(self):
                ordering.append('test_something')

        Test.__module__ = 'Module'
        self._setModule('Module', Module)

        suite = unittest2.defaultTestLoader.loadTestsFromTestCase(Test)
        suite.debug()
        expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule']
        self.assertEqual(ordering, expectedOrder)

    def test_suite_debug_propagates_exceptions(self):
        # TestSuite.debug() must let exceptions from any fixture or test
        # escape. The fixtures below read ``phase`` through a closure; it is
        # bound by the loop at the end, so each iteration makes exactly one
        # stage raise.
        class Module(object):
            @staticmethod
            def setUpModule():
                if phase == 0:
                    raise Exception('setUpModule')
            @staticmethod
            def tearDownModule():
                if phase == 1:
                    raise Exception('tearDownModule')

        class Test(unittest2.TestCase):
            @classmethod
            def setUpClass(cls):
                if phase == 2:
                    raise Exception('setUpClass')
            @classmethod
            def tearDownClass(cls):
                if phase == 3:
                    raise Exception('tearDownClass')
            def test_something(self):
                if phase == 4:
                    raise Exception('test_something')

        Test.__module__ = 'Module'
        self._setModule('Module', Module)

        _suite = unittest2.defaultTestLoader.loadTestsFromTestCase(Test)
        suite = unittest2.TestSuite()

        # nesting a suite again exposes a bug in the initial implementation
        suite.addTest(_suite)
        messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something')
        for phase, msg in enumerate(messages):
            self.assertRaisesRegexp(Exception, msg, suite.debug)