1"""TestSuite"""
2
3import sys
4import unittest
5from unittest2 import case, util
6
# Module-level marker flag. NOTE(review): presumably read by unittest's
# result/traceback formatting to hide this module's frames from failure
# reports; its consumer is not visible in this file, so confirm.
__unittest = True
8
9
class BaseTestSuite(unittest.TestSuite):
    """A simple test suite that doesn't provide class or module shared fixtures.

    Tests are stored in insertion order in ``self._tests``. Running the suite
    calls each test with the result object, checking ``result.shouldStop``
    before every test so a runner can abort early.
    """

    # ``basestring`` only exists on Python 2; fall back to ``str`` so the
    # string guard in addTests() does not raise NameError on Python 3.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self, tests=()):
        """Create a suite containing *tests* (an iterable of tests)."""
        self._tests = []
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))

    def __eq__(self, other):
        # Only suites of the same (sub)class compare; equality is by the
        # ordered list of contained tests.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return list(self) == list(other)

    def __ne__(self, other):
        return not self == other

    # Can't guarantee hash invariant, so flag as unhashable
    __hash__ = None

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the sum of ``countTestCases()`` over the contained tests."""
        return sum(test.countTestCases() for test in self)

    def addTest(self, test):
        """Append a single *test* to the suite.

        Raises TypeError if *test* is not callable, or if it is a TestCase or
        TestSuite class rather than an instance of one.
        """
        # sanity checks
        if not hasattr(test, '__call__'):
            # Format the object itself with %r; wrapping it in repr() first
            # would quote the repr a second time.
            raise TypeError("%r is not callable" % (test,))
        if isinstance(test, type) and issubclass(test,
                                                 (case.TestCase, TestSuite)):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests*, in order.

        A string is rejected explicitly because it is iterable but is never
        a collection of tests.
        """
        if isinstance(tests, self._string_types):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Call each test with *result* until ``result.shouldStop``; return *result*."""
        for test in self:
            if result.shouldStop:
                break
            test(result)
        return result

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self:
            test.debug()
70
71
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    Beyond what BaseTestSuite does, this suite manages shared fixtures.
    Whenever the next test's class differs from the previous one, it:

    * calls ``tearDownClass`` on the old class;
    * switches module fixtures (``tearDownModule`` / ``setUpModule``) if the
      defining module changed;
    * calls ``setUpClass`` on the new class.

    Bookkeeping lives on the result object (``_previousTestClass``,
    ``_moduleSetUpFailed``) and on test classes (``_classSetupFailed``).
    """

    def run(self, result):
        """Run all tests into *result*, tear down remaining fixtures, return *result*."""
        self._wrapped_run(result)
        # None as the "next test" makes the class differ from the last one
        # run, so its tearDownClass fires; then the last module is torn down.
        self._tearDownPreviousClass(None, result)
        self._handleModuleTearDown(result)
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult.

        A _DebugResult stands in for the result; the fixture handlers re-raise
        instead of recording when they see one.
        """
        debug = _DebugResult()
        self._wrapped_run(debug, True)
        self._tearDownPreviousClass(None, debug)
        self._handleModuleTearDown(debug)

    ################################
    # private methods
    def _wrapped_run(self, result, debug=False):
        """Run the suite's tests into *result* without the final teardown.

        Nested suites that provide ``_wrapped_run`` are recursed into, so the
        fixture bookkeeping on *result* is shared across the whole tree.
        Other tests are called with *result*, or have ``debug()`` called
        when *debug* is true.
        """
        for test in self:
            if result.shouldStop:
                break

            if _isnotsuite(test):
                # Individual test: switch class/module fixtures as needed.
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Don't run tests whose class or module fixture failed; the
                # failure was recorded when the fixture ran.
                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if hasattr(test, '_wrapped_run'):
                test._wrapped_run(result, debug)
            elif not debug:
                test(result)
            else:
                test.debug()

    def _handleClassSetUp(self, test, result):
        """Call ``setUpClass`` for *test*'s class if it differs from the previous one.

        Skipped when the class is unchanged, when the module fixture failed,
        or when the class is marked ``__unittest_skip__``. On failure the
        class is flagged with ``_classSetupFailed`` and the error is recorded
        on *result* (re-raised when *result* is a _DebugResult).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        # getattr, like _wrapped_run/_tearDownPreviousClass, so a result
        # object that never had the flag set is tolerated.
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            try:
                setUpClass()
            except Exception:
                # Fetched via sys.exc_info() so the handler parses on both
                # Python 2 and Python 3.
                e = sys.exc_info()[1]
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)

    def _get_previous_module(self, result):
        """Return the module name of the previously run test class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """Switch module fixtures when *test* comes from a different module.

        Tears down the previous module, then calls ``setUpModule`` of the
        module defining ``test.__class__``, provided that module is in
        sys.modules and defines one. A failure sets
        ``result._moduleSetUpFailed`` and is recorded on *result*
        (re-raised when *result* is a _DebugResult).
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        # Reset before the lookup so an earlier module's failure does not
        # carry over, even when this module cannot be found.
        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            try:
                setUpModule()
            except Exception:
                e = sys.exc_info()[1]
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Record a fixture error on *result* under the label *errorName*.

        A ``case.SkipTest`` is reported through ``result.addSkip`` when the
        result has one; anything else goes to ``result.addError``. This must
        be called from inside the ``except`` block that caught *exception*,
        because the traceback is read from ``sys.exc_info()``.
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call ``tearDownModule`` on the module of the previously run test.

        Does nothing in any of these cases:

        * no test has run yet;
        * the module fixture failed;
        * the module is not in sys.modules;
        * the module defines no ``tearDownModule``.

        Errors are recorded on *result* (re-raised when *result* is a
        _DebugResult).
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        # getattr for consistency with the other fixture handlers.
        if getattr(result, '_moduleSetUpFailed', False):
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            try:
                tearDownModule()
            except Exception:
                e = sys.exc_info()[1]
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _tearDownPreviousClass(self, test, result):
        """Call ``tearDownClass`` on the previous test class if the class is changing.

        *test* is the next test to run, or None at the end of a run. No
        teardown happens in any of these cases:

        * the class is unchanged;
        * the previous class's setUpClass failed;
        * the module fixture failed;
        * the previous class is marked ``__unittest_skip__``.

        Errors are recorded on *result* (re-raised when *result* is a
        _DebugResult).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            try:
                tearDownClass()
            except Exception:
                e = sys.exc_info()[1]
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
234
235
class _ErrorHolder(object):
    """Stand-in for a TestCase when reporting an error that belongs to no test.

    TestSuite wraps class- and module-level fixture errors (for example
    ``setUpClass (...)`` or ``tearDownModule (...)``) in one of these before
    handing them to a TestResult. It mimics just enough of the test-case
    interface to be recorded and described; running it is a no-op.
    """
    # Modelled on the ErrorHolder from Twisted:
    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py

    # attribute used by TestResult._exc_info_to_string
    failureException = None

    def __init__(self, description):
        self.description = description

    def __repr__(self):
        return "<ErrorHolder description=%r>" % (self.description,)

    def __str__(self):
        return self.id()

    def id(self):
        return self.description

    def shortDescription(self):
        return None

    def countTestCases(self):
        # Contributes nothing to the suite's test count.
        return 0

    def __call__(self, result):
        return self.run(result)

    def run(self, result):
        # Deliberately inert: the error was already recorded when this
        # holder was created, so there is nothing to execute.
        return None
273
def _isnotsuite(test):
    """Return True if *test* looks like a single test rather than a suite.

    Crude duck-typing heuristic: anything that can be iterated is treated
    as a suite; everything else is treated as an individual test.
    """
    try:
        iter(test)
    except TypeError:
        # Not iterable, so this is an individual test (or test-like callable).
        single = True
    else:
        single = False
    return single
281
282
class _DebugResult(object):
    """Placeholder result used by TestSuite.debug().

    Holds only the state the suite's fixture handling reads and writes
    (the previous test class and the module-setup failure flag) while
    running in debug mode.
    """
    # Never asks the suite to stop early.
    shouldStop = False
    # Set by the suite when a setUpModule call fails.
    _moduleSetUpFailed = False
    # Class of the most recently run test; None until a test has run.
    _previousTestClass = None
288