1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8import unittest
9
10from fnmatch import fnmatch
11
12from unittest2 import case, suite
13
14try:
15    from os.path import relpath
16except ImportError:
17    from unittest2.compatibility import relpath
18
# Module-level marker flag. Presumably checked by unittest2's result
# formatting so that frames from this module are hidden in failure
# tracebacks — NOTE(review): confirm against unittest2.result.
__unittest = True
20
21
22def _CmpToKey(mycmp):
23    'Convert a cmp= function into a key= function'
24    class K(object):
25        def __init__(self, obj):
26            self.obj = obj
27        def __lt__(self, other):
28            return mycmp(self.obj, other.obj) == -1
29    return K
30
31
# File names that test discovery will consider importing: a letter or
# underscore followed by word characters, ending in ".py" (matched
# case-insensitively, anchored at the start by .match() and at the end by $).
# Compiled files (.pyc, .pyo, etc.) are not matched. Supporting them would
# require avoiding loading the same tests multiple times from '.py', '.pyc'
# *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
36
37
def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that fails with an ImportError for name.

    Intended to be called while the import exception is being handled: when
    ``traceback.format_exc`` exists, the current traceback is appended to
    the message. Its output also includes frames from the loader itself.
    """
    pieces = ['Failed to import test module: %s' % name]
    # traceback.format_exc does not exist on every supported Python
    # (it is missing on 2.3), hence the feature check.
    if hasattr(traceback, 'format_exc'):
        pieces.append(traceback.format_exc())
    import_error = ImportError('\n'.join(pieces))
    return _make_failed_test('ModuleImportFailure', name, import_error,
                             suiteClass)
46
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite whose single 'LoadTestsFailure' test re-raises exception.

    Used when a module's or package's load_tests hook raised; name becomes
    the failing test's method name.
    """
    return _make_failed_test(classname='LoadTestsFailure',
                             methodname=name,
                             exception=exception,
                             suiteClass=suiteClass)
49
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a suite holding one test that raises exception when run.

    A TestCase subclass named classname is created on the fly with a single
    method called methodname. That method raises exception. The returned
    suite (built with suiteClass) contains one instance of it.
    """
    def testFailure(self):
        raise exception

    failing_case_class = type(classname, (case.TestCase,),
                              {methodname: testFailure})
    failing_instance = failing_case_class(methodname)
    return suiteClass((failing_instance,))
56
57
58class TestLoader(unittest.TestLoader):
59    """
60    This class is responsible for loading tests according to various criteria
61    and returning them wrapped in a TestSuite
62    """
63    testMethodPrefix = 'test'
64    sortTestMethodsUsing = cmp
65    suiteClass = suite.TestSuite
66    _top_level_dir = None
67
68    def loadTestsFromTestCase(self, testCaseClass):
69        """Return a suite of all tests cases contained in testCaseClass"""
70        if issubclass(testCaseClass, suite.TestSuite):
71            raise TypeError("Test cases should not be derived from TestSuite."
72                            " Maybe you meant to derive from TestCase?")
73        testCaseNames = self.getTestCaseNames(testCaseClass)
74        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
75            testCaseNames = ['runTest']
76        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
77        return loaded_suite
78
79    def loadTestsFromModule(self, module, use_load_tests=True):
80        """Return a suite of all tests cases contained in the given module"""
81        tests = []
82        for name in dir(module):
83            obj = getattr(module, name)
84            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
85                tests.append(self.loadTestsFromTestCase(obj))
86
87        load_tests = getattr(module, 'load_tests', None)
88        tests = self.suiteClass(tests)
89        if use_load_tests and load_tests is not None:
90            try:
91                return load_tests(self, tests, None)
92            except Exception, e:
93                return _make_failed_load_tests(module.__name__, e,
94                                               self.suiteClass)
95        return tests
96
97    def loadTestsFromName(self, name, module=None):
98        """Return a suite of all tests cases given a string specifier.
99
100        The name may resolve either to a module, a test case class, a
101        test method within a test case class, or a callable object which
102        returns a TestCase or TestSuite instance.
103
104        The method optionally resolves the names relative to a given module.
105        """
106        parts = name.split('.')
107        if module is None:
108            parts_copy = parts[:]
109            while parts_copy:
110                try:
111                    module = __import__('.'.join(parts_copy))
112                    break
113                except ImportError:
114                    del parts_copy[-1]
115                    if not parts_copy:
116                        raise
117            parts = parts[1:]
118        obj = module
119        for part in parts:
120            parent, obj = obj, getattr(obj, part)
121
122        if isinstance(obj, types.ModuleType):
123            return self.loadTestsFromModule(obj)
124        elif isinstance(obj, type) and issubclass(obj, unittest.TestCase):
125            return self.loadTestsFromTestCase(obj)
126        elif (isinstance(obj, types.UnboundMethodType) and
127              isinstance(parent, type) and
128              issubclass(parent, case.TestCase)):
129            return self.suiteClass([parent(obj.__name__)])
130        elif isinstance(obj, unittest.TestSuite):
131            return obj
132        elif hasattr(obj, '__call__'):
133            test = obj()
134            if isinstance(test, unittest.TestSuite):
135                return test
136            elif isinstance(test, unittest.TestCase):
137                return self.suiteClass([test])
138            else:
139                raise TypeError("calling %s returned %s, not a test" %
140                                (obj, test))
141        else:
142            raise TypeError("don't know how to make test from: %s" % obj)
143
144    def loadTestsFromNames(self, names, module=None):
145        """Return a suite of all tests cases found using the given sequence
146        of string specifiers. See 'loadTestsFromName()'.
147        """
148        suites = [self.loadTestsFromName(name, module) for name in names]
149        return self.suiteClass(suites)
150
151    def getTestCaseNames(self, testCaseClass):
152        """Return a sorted sequence of method names found within testCaseClass
153        """
154        def isTestMethod(attrname, testCaseClass=testCaseClass,
155                         prefix=self.testMethodPrefix):
156            return attrname.startswith(prefix) and \
157                hasattr(getattr(testCaseClass, attrname), '__call__')
158        testFnNames = filter(isTestMethod, dir(testCaseClass))
159        if self.sortTestMethodsUsing:
160            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
161        return testFnNames
162
163    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
164        """Find and return all test modules from the specified start
165        directory, recursing into subdirectories to find them. Only test files
166        that match the pattern will be loaded. (Using shell style pattern
167        matching.)
168
169        All test modules must be importable from the top level of the project.
170        If the start directory is not the top level directory then the top
171        level directory must be specified separately.
172
173        If a test package name (directory with '__init__.py') matches the
174        pattern then the package will be checked for a 'load_tests' function. If
175        this exists then it will be called with loader, tests, pattern.
176
177        If load_tests exists then discovery does  *not* recurse into the package,
178        load_tests is responsible for loading all tests in the package.
179
180        The pattern is deliberately not stored as a loader attribute so that
181        packages can continue discovery themselves. top_level_dir is stored so
182        load_tests does not need to pass this argument in to loader.discover().
183        """
184        set_implicit_top = False
185        if top_level_dir is None and self._top_level_dir is not None:
186            # make top_level_dir optional if called from load_tests in a package
187            top_level_dir = self._top_level_dir
188        elif top_level_dir is None:
189            set_implicit_top = True
190            top_level_dir = start_dir
191
192        top_level_dir = os.path.abspath(top_level_dir)
193
194        if not top_level_dir in sys.path:
195            # all test modules must be importable from the top level directory
196            # should we *unconditionally* put the start directory in first
197            # in sys.path to minimise likelihood of conflicts between installed
198            # modules and development versions?
199            sys.path.insert(0, top_level_dir)
200        self._top_level_dir = top_level_dir
201
202        is_not_importable = False
203        if os.path.isdir(os.path.abspath(start_dir)):
204            start_dir = os.path.abspath(start_dir)
205            if start_dir != top_level_dir:
206                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
207        else:
208            # support for discovery from dotted module names
209            try:
210                __import__(start_dir)
211            except ImportError:
212                is_not_importable = True
213            else:
214                the_module = sys.modules[start_dir]
215                top_part = start_dir.split('.')[0]
216                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
217                if set_implicit_top:
218                    self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__)))
219                    sys.path.remove(top_level_dir)
220
221        if is_not_importable:
222            raise ImportError('Start directory is not importable: %r' % start_dir)
223
224        tests = list(self._find_tests(start_dir, pattern))
225        return self.suiteClass(tests)
226
227    def _get_name_from_path(self, path):
228        path = os.path.splitext(os.path.normpath(path))[0]
229
230        _relpath = relpath(path, self._top_level_dir)
231        assert not os.path.isabs(_relpath), "Path must be within the project"
232        assert not _relpath.startswith('..'), "Path must be within the project"
233
234        name = _relpath.replace(os.path.sep, '.')
235        return name
236
237    def _get_module_from_name(self, name):
238        __import__(name)
239        return sys.modules[name]
240
241    def _match_path(self, path, full_path, pattern):
242        # override this method to use alternative matching strategy
243        return fnmatch(path, pattern)
244
245    def _find_tests(self, start_dir, pattern):
246        """Used by discovery. Yields test suites it loads."""
247        paths = os.listdir(start_dir)
248
249        for path in paths:
250            full_path = os.path.join(start_dir, path)
251            if os.path.isfile(full_path):
252                if not VALID_MODULE_NAME.match(path):
253                    # valid Python identifiers only
254                    continue
255                if not self._match_path(path, full_path, pattern):
256                    continue
257                # if the test file matches, load it
258                name = self._get_name_from_path(full_path)
259                try:
260                    module = self._get_module_from_name(name)
261                except:
262                    yield _make_failed_import_test(name, self.suiteClass)
263                else:
264                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
265                    realpath = os.path.splitext(mod_file)[0]
266                    fullpath_noext = os.path.splitext(full_path)[0]
267                    if realpath.lower() != fullpath_noext.lower():
268                        module_dir = os.path.dirname(realpath)
269                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
270                        expected_dir = os.path.dirname(full_path)
271                        msg = ("%r module incorrectly imported from %r. Expected %r. "
272                               "Is this module globally installed?")
273                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
274                    yield self.loadTestsFromModule(module)
275            elif os.path.isdir(full_path):
276                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
277                    continue
278
279                load_tests = None
280                tests = None
281                if fnmatch(path, pattern):
282                    # only check load_tests if the package directory itself matches the filter
283                    name = self._get_name_from_path(full_path)
284                    package = self._get_module_from_name(name)
285                    load_tests = getattr(package, 'load_tests', None)
286                    tests = self.loadTestsFromModule(package, use_load_tests=False)
287
288                if load_tests is None:
289                    if tests is not None:
290                        # tests loaded from package file
291                        yield tests
292                    # recurse into the package
293                    for test in self._find_tests(full_path, pattern):
294                        yield test
295                else:
296                    try:
297                        yield load_tests(self, tests, pattern)
298                    except Exception, e:
299                        yield _make_failed_load_tests(package.__name__, e,
300                                                      self.suiteClass)
301
# Shared module-level loader created with the TestLoader class defaults.
# Note: discover() stores _top_level_dir on the instance it is called on,
# so discovery through this shared loader affects later discover() calls
# that omit top_level_dir.
defaultTestLoader = TestLoader()
303
304
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given settings.

    prefix and sortUsing always override the loader's defaults. suiteClass
    only overrides when it is truthy.
    """
    loader = TestLoader()
    loader.testMethodPrefix, loader.sortTestMethodsUsing = prefix, sortUsing
    if not suiteClass:
        return loader
    loader.suiteClass = suiteClass
    return loader
312
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return testCaseClass's test method names.

    The names start with prefix and are ordered with the cmp-style sortUsing.
    """
    configured_loader = _makeLoader(prefix, sortUsing)
    return configured_loader.getTestCaseNames(testCaseClass)
315
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite (of type suiteClass) with the tests of testCaseClass.

    Test methods are those whose names start with prefix. They are ordered
    with the cmp-style sortUsing.
    """
    configured_loader = _makeLoader(prefix, sortUsing, suiteClass)
    return configured_loader.loadTestsFromTestCase(testCaseClass)
319
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite (of type suiteClass) with every test found in module.

    Loading goes through TestLoader.loadTestsFromModule, so a module-level
    load_tests hook is honoured. Test methods are selected by prefix and
    ordered with sortUsing.
    """
    configured_loader = _makeLoader(prefix, sortUsing, suiteClass)
    return configured_loader.loadTestsFromModule(module)
323