1#!/usr/bin/env python
2# Copyright (c) 2014 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5import unittest
6import sys
7import os
8import optparse
9
# An empty __all__ means `from <this module> import *` binds no names.
# NOTE(review): TestRunner looks like the intended public entry point;
# confirm that exporting nothing via star-import is deliberate.
__all__ = []
11
def FilterSuite(suite, predicate):
  """Build a filtered copy of a (possibly nested) test suite.

  Args:
    suite: The unittest.TestSuite to filter. The result is created by calling
      suite.__class__() with no arguments, so suite subclasses are preserved.
    predicate: Callable invoked with each unittest.TestCase; a truthy result
      keeps the test.

  Returns:
    A new suite holding the kept test cases. Nested suites are filtered
    recursively and omitted entirely when nothing in them survives.
  """
  filtered = suite.__class__()
  for item in suite:
    if not isinstance(item, unittest.TestSuite):
      # Anything that is not a suite must be an individual test case.
      assert isinstance(item, unittest.TestCase)
      if predicate(item):
        filtered.addTest(item)
    else:
      child = FilterSuite(item, predicate)
      # Skip nested suites left empty so the output has no hollow groups.
      if child.countTestCases() != 0:
        filtered.addTest(child)
  return filtered
29
class _TestLoader(unittest.TestLoader):
  """TestLoader that turns this module into the registered discover() calls.

  TestRunner.Main passes this module to unittest.main together with this
  loader. When asked for this module's tests, the loader runs
  self.discover(*args) for every tuple in discover_calls and returns the
  combined suite; any other module is loaded by the base class as usual.
  """

  def __init__(self, *args):
    super(_TestLoader, self).__init__(*args)
    # Argument tuples for self.discover(); TestRunner.AddDirectory appends
    # (start_dir, pattern, top_level_dir) entries here.
    self.discover_calls = []

  def loadTestsFromModule(self, module, *args, **kwargs):
    """Return the tests for module.

    Extra positional and keyword arguments are forwarded untouched to the
    base implementation. Its signature differs across Python versions
    (Python 2 accepts use_load_tests, while Python 3's discover() passes
    pattern=), so forwarding keeps this override compatible with both.

    Args:
      module: The module object to load tests from.

    Returns:
      A unittest.TestSuite.
    """
    # getattr: some modules (e.g. namespace packages) may lack a usable
    # __file__; treat them as ordinary modules instead of raising.
    if getattr(module, '__file__', None) != __file__:
      return super(_TestLoader, self).loadTestsFromModule(
          module, *args, **kwargs)

    # This is the runner module itself: substitute the discovered tests.
    suite = unittest.TestSuite()
    for discover_args in self.discover_calls:
      suite.addTest(self.discover(*discover_args))
    return suite
45
class _RunnerImpl(unittest.TextTestRunner):
  """Verbose TextTestRunner that only runs tests matching name filters.

  A test runs when no filters were given, or when any filter string occurs
  as a substring of the test's id().
  """

  def __init__(self, filters):
    """Create the runner.

    Args:
      filters: Sequence of substrings matched against test ids; an empty
        sequence (or any falsy value) means every test runs.
    """
    super(_RunnerImpl, self).__init__(verbosity=2)
    self.filters = filters

  def ShouldTestRun(self, test):
    """Return True if test should run under the configured filters."""
    if not self.filters:
      return True
    test_id = test.id()
    return any(pattern in test_id for pattern in self.filters)

  def run(self, suite):
    """Drop tests rejected by ShouldTestRun, then run the remainder."""
    selected = FilterSuite(suite, self.ShouldTestRun)
    return super(_RunnerImpl, self).run(selected)
57
58
class TestRunner(object):
  """Collects test directories and runs the tests discovered in them.

  Call AddDirectory() for each directory to scan, then Main(). Positional
  command-line arguments given to Main() act as substring filters on test ids.
  """

  def __init__(self):
    # Loader that accumulates the discover() calls registered below.
    self._loader = _TestLoader()

  def AddDirectory(self, dir_path, test_file_pattern="*test.py"):
    """Register dir_path for test discovery.

    Args:
      dir_path: Directory to search. It is used both as the discovery start
        directory and as the top-level directory.
      test_file_pattern: Glob matched against file names to find test modules.

    Raises:
      ValueError: if dir_path is not an existing directory.
    """
    # An explicit check rather than an assert, so it still runs under -O.
    if not os.path.isdir(dir_path):
      raise ValueError('Not a directory: %r' % (dir_path,))

    self._loader.discover_calls.append((dir_path, test_file_pattern, dir_path))

  def Main(self, argv=None):
    """Parse argv and run all registered tests through unittest.main.

    Args:
      argv: Command line with the program name first; defaults to sys.argv.
        Every positional argument after the program name is a substring
        filter on test ids; with none, every discovered test runs.

    Returns:
      Whatever unittest.main returns. unittest.main calls sys.exit() by
      default (exit=True), so in practice this normally does not return.
    """
    if argv is None:
      argv = sys.argv

    parser = optparse.OptionParser()
    _, args = parser.parse_args(argv[1:])

    runner = _RunnerImpl(filters=args)
    # The filters are applied by |runner|; hand unittest.main only the
    # program name (from the argv actually given) so it does not treat the
    # filters as test names.
    prog = argv[0] if argv else sys.argv[0]
    return unittest.main(module=__name__, argv=[prog],
                         testLoader=self._loader,
                         testRunner=runner)
79