1"""Test case implementation"""
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import types
10import warnings
12from . import result
13from .util import (
14    strclass, safe_repr, unorderable_list_difference,
15    _count_diff_all_purpose, _count_diff_hashable
# Module-level marker flag; presumably inspected elsewhere in the package to
# recognise frames that belong to this module (e.g. to hide them from failure
# tracebacks) -- TODO confirm against the result module.
__unittest = True

# Text appended by TestCase._truncateMessage in place of an over-long diff;
# the single %s placeholder receives the diff's length in characters.
DIFF_OMITTED = (
    '\nDiff is %s characters long. '
    'Set self.maxDiff to None to see it.'
)
class SkipTest(Exception):
    """Signal that the currently running test should be skipped.

    Tests normally skip through TestCase.skipTest() or one of the skip
    decorators rather than raising this exception themselves.
    """
class _ExpectedFailure(Exception):
    """Internal marker raised when a test wrapped by expectedFailure fails.

    ``exc_info`` holds the ``sys.exc_info()`` triple captured from the
    underlying failure so it can be reported later. Implementation detail.
    """

    def __init__(self, exc_info):
        super(_ExpectedFailure, self).__init__()
        # Keep the original failure details for the result object.
        self.exc_info = exc_info
class _UnexpectedSuccess(Exception):
    """Internal marker: a test expected to fail completed without error."""
def _id(obj):
    """No-op decorator: hand *obj* back untouched."""
    unchanged = obj
    return unchanged
def skip(reason):
    """Return a decorator that unconditionally skips a test function or class.

    Classes are returned as-is; anything else is replaced by a wrapper that
    raises SkipTest(reason) when called. In both cases the resulting object
    is tagged with ``__unittest_skip__`` and ``__unittest_skip_why__``.
    """
    def decorator(test_item):
        if isinstance(test_item, (type, types.ClassType)):
            marked = test_item
        else:
            @functools.wraps(test_item)
            def marked(*args, **kwargs):
                raise SkipTest(reason)
        marked.__unittest_skip__ = True
        marked.__unittest_skip_why__ = reason
        return marked
    return decorator
def skipIf(condition, reason):
    """Skip the decorated test when *condition* is truthy."""
    return skip(reason) if condition else _id
def skipUnless(condition, reason):
    """Skip the decorated test unless *condition* is truthy."""
    return _id if condition else skip(reason)
def expectedFailure(func):
    """Mark *func* as a test that is expected to fail.

    Any Exception raised by *func* is re-raised as _ExpectedFailure carrying
    the original ``sys.exc_info()``; returning normally raises
    _UnexpectedSuccess instead.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
class _AssertRaisesContext(object):
    """Context manager behind TestCase.assertRaises.

    On exit it checks that an exception of (a subclass of) ``expected`` was
    raised inside the ``with`` block and, when ``expected_regexp`` is set,
    that ``str()`` of the exception matches it. The caught exception is kept
    in ``self.exception``.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        # expected is whatever issubclass() accepts as its second argument.
        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            # The block completed without raising anything.
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of classes has no __name__
                exc_name = str(self.expected)
            raise self.failureException("{0} not raised".format(exc_name))

        if not issubclass(exc_type, self.expected):
            # Unrelated exception: returning False lets it propagate.
            return False

        self.exception = exc_value  # inspectable by the caller afterwards
        pattern = self.expected_regexp
        if pattern is not None:
            # Plain string patterns are compiled on demand.
            if isinstance(pattern, basestring):
                pattern = re.compile(pattern)
            if not pattern.search(str(exc_value)):
                raise self.failureException('"%s" does not match "%s"' %
                                            (pattern.pattern, str(exc_value)))
        return True
133class TestCase(object):
134    """A class whose instances are single test cases.
136    By default, the test code itself should be placed in a method named
137    'runTest'.
139    If the fixture may be used for many test cases, create as
140    many test methods as are needed. When instantiating such a TestCase
141    subclass, specify in the constructor arguments the name of the test method
142    that the instance is to execute.
144    Test authors should subclass TestCase for their own tests. Construction
145    and deconstruction of the test's environment ('fixture') can be
146    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
148    If it is necessary to override the __init__ method, the base class
149    __init__ method must always be called. It is important that subclasses
150    should not change the signature of their __init__ method, since instances
151    of the classes are instantiated automatically by parts of the framework
152    in order to be run.
154    When subclassing TestCase, you can set these attributes:
155    * failureException: determines which exception will be raised when
156        the instance's assertion methods fail; test methods raising this
157        exception will be deemed to have 'failed' rather than 'errored'.
158    * longMessage: determines whether long messages (including repr of
159        objects used in assert methods) will be printed on failure in *addition*
160        to any explicit message passed.
161    * maxDiff: sets the maximum length of a diff in failure messages
162        by assert methods using difflib. It is looked up as an instance
163        attribute so can be configured by individual tests if required.
164    """
166    failureException = AssertionError
168    longMessage = False
170    maxDiff = 80*8
172    # If a string is longer than _diffThreshold, use normal comparison instead
173    # of difflib.  See #11763.
174    _diffThreshold = 2**16
176    # Attribute used by TestSuite for classSetUp
178    _classSetupFailed = False
180    def __init__(self, methodName='runTest'):
181        """Create an instance of the class that will use the named test
182           method when executed. Raises a ValueError if the instance does
183           not have a method with the specified name.
184        """
185        self._testMethodName = methodName
186        self._resultForDoCleanups = None
187        try:
188            testMethod = getattr(self, methodName)
189        except AttributeError:
190            raise ValueError("no such test method in %s: %s" %
191                  (self.__class__, methodName))
192        self._testMethodDoc = testMethod.__doc__
193        self._cleanups = []
195        # Map types to custom assertEqual functions that will compare
196        # instances of said type in more detail to generate a more useful
197        # error message.
198        self._type_equality_funcs = {}
199        self.addTypeEqualityFunc(dict, 'assertDictEqual')
200        self.addTypeEqualityFunc(list, 'assertListEqual')
201        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
202        self.addTypeEqualityFunc(set, 'assertSetEqual')
203        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
204        try:
205            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
206        except NameError:
207            # No unicode support in this build
208            pass
210    def addTypeEqualityFunc(self, typeobj, function):
211        """Add a type specific assertEqual style function to compare a type.
213        This method is for use by TestCase subclasses that need to register
214        their own type equality functions to provide nicer error messages.
216        Args:
217            typeobj: The data type to call this function on when both values
218                    are of the same type in assertEqual().
219            function: The callable taking two arguments and an optional
220                    msg= argument that raises self.failureException with a
221                    useful error message when the two arguments are not equal.
222        """
223        self._type_equality_funcs[typeobj] = function
225    def addCleanup(self, function, *args, **kwargs):
226        """Add a function, with arguments, to be called when the test is
227        completed. Functions added are called on a LIFO basis and are
228        called after tearDown on test failure or success.
230        Cleanup items are called even if setUp fails (unlike tearDown)."""
231        self._cleanups.append((function, args, kwargs))
233    def setUp(self):
234        "Hook method for setting up the test fixture before exercising it."
235        pass
237    def tearDown(self):
238        "Hook method for deconstructing the test fixture after testing it."
239        pass
241    @classmethod
242    def setUpClass(cls):
243        "Hook method for setting up class fixture before running tests in the class."
245    @classmethod
246    def tearDownClass(cls):
247        "Hook method for deconstructing the class fixture after running all tests in the class."
249    def countTestCases(self):
250        return 1
252    def defaultTestResult(self):
253        return result.TestResult()
255    def shortDescription(self):
256        """Returns a one-line description of the test, or None if no
257        description has been provided.
259        The default implementation of this method returns the first line of
260        the specified test method's docstring.
261        """
262        doc = self._testMethodDoc
263        return doc and doc.split("\n")[0].strip() or None
266    def id(self):
267        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
269    def __eq__(self, other):
270        if type(self) is not type(other):
271            return NotImplemented
273        return self._testMethodName == other._testMethodName
275    def __ne__(self, other):
276        return not self == other
278    def __hash__(self):
279        return hash((type(self), self._testMethodName))
281    def __str__(self):
282        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
284    def __repr__(self):
285        return "<%s testMethod=%s>" % \
286               (strclass(self.__class__), self._testMethodName)
288    def _addSkip(self, result, reason):
289        addSkip = getattr(result, 'addSkip', None)
290        if addSkip is not None:
291            addSkip(self, reason)
292        else:
293            warnings.warn("TestResult has no addSkip method, skips not reported",
294                          RuntimeWarning, 2)
295            result.addSuccess(self)
297    def run(self, result=None):
298        orig_result = result
299        if result is None:
300            result = self.defaultTestResult()
301            startTestRun = getattr(result, 'startTestRun', None)
302            if startTestRun is not None:
303                startTestRun()
305        self._resultForDoCleanups = result
306        result.startTest(self)
308        testMethod = getattr(self, self._testMethodName)
309        if (getattr(self.__class__, "__unittest_skip__", False) or
310            getattr(testMethod, "__unittest_skip__", False)):
311            # If the class or method was skipped.
312            try:
313                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
314                            or getattr(testMethod, '__unittest_skip_why__', ''))
315                self._addSkip(result, skip_why)
316            finally:
317                result.stopTest(self)
318            return
319        try:
320            success = False
321            try:
322                self.setUp()
323            except SkipTest as e:
324                self._addSkip(result, str(e))
325            except KeyboardInterrupt:
326                raise
327            except:
328                result.addError(self, sys.exc_info())
329            else:
330                try:
331                    testMethod()
332                except KeyboardInterrupt:
333                    raise
334                except self.failureException:
335                    result.addFailure(self, sys.exc_info())
336                except _ExpectedFailure as e:
337                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
338                    if addExpectedFailure is not None:
339                        addExpectedFailure(self, e.exc_info)
340                    else:
341                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
342                                      RuntimeWarning)
343                        result.addSuccess(self)
344                except _UnexpectedSuccess:
345                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
346                    if addUnexpectedSuccess is not None:
347                        addUnexpectedSuccess(self)
348                    else:
349                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
350                                      RuntimeWarning)
351                        result.addFailure(self, sys.exc_info())
352                except SkipTest as e:
353                    self._addSkip(result, str(e))
354                except:
355                    result.addError(self, sys.exc_info())
356                else:
357                    success = True
359                try:
360                    self.tearDown()
361                except KeyboardInterrupt:
362                    raise
363                except:
364                    result.addError(self, sys.exc_info())
365                    success = False
367            cleanUpSuccess = self.doCleanups()
368            success = success and cleanUpSuccess
369            if success:
370                result.addSuccess(self)
371        finally:
372            result.stopTest(self)
373            if orig_result is None:
374                stopTestRun = getattr(result, 'stopTestRun', None)
375                if stopTestRun is not None:
376                    stopTestRun()
378    def doCleanups(self):
379        """Execute all cleanup functions. Normally called for you after
380        tearDown."""
381        result = self._resultForDoCleanups
382        ok = True
383        while self._cleanups:
384            function, args, kwargs = self._cleanups.pop(-1)
385            try:
386                function(*args, **kwargs)
387            except KeyboardInterrupt:
388                raise
389            except:
390                ok = False
391                result.addError(self, sys.exc_info())
392        return ok
394    def __call__(self, *args, **kwds):
395        return self.run(*args, **kwds)
397    def debug(self):
398        """Run the test without collecting errors in a TestResult"""
399        self.setUp()
400        getattr(self, self._testMethodName)()
401        self.tearDown()
402        while self._cleanups:
403            function, args, kwargs = self._cleanups.pop(-1)
404            function(*args, **kwargs)
406    def skipTest(self, reason):
407        """Skip this test."""
408        raise SkipTest(reason)
410    def fail(self, msg=None):
411        """Fail immediately, with the given message."""
412        raise self.failureException(msg)
414    def assertFalse(self, expr, msg=None):
415        """Check that the expression is false."""
416        if expr:
417            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
418            raise self.failureException(msg)
420    def assertTrue(self, expr, msg=None):
421        """Check that the expression is true."""
422        if not expr:
423            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
424            raise self.failureException(msg)
426    def _formatMessage(self, msg, standardMsg):
427        """Honour the longMessage attribute when generating failure messages.
428        If longMessage is False this means:
429        * Use only an explicit message if it is provided
430        * Otherwise use the standard message for the assert
432        If longMessage is True:
433        * Use the standard message
434        * If an explicit message is provided, plus ' : ' and the explicit message
435        """
436        if not self.longMessage:
437            return msg or standardMsg
438        if msg is None:
439            return standardMsg
440        try:
441            # don't switch to '{}' formatting in Python 2.X
442            # it changes the way unicode input is handled
443            return '%s : %s' % (standardMsg, msg)
444        except UnicodeDecodeError:
445            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
448    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
449        """Fail unless an exception of class excClass is raised
450           by callableObj when invoked with arguments args and keyword
451           arguments kwargs. If a different type of exception is
452           raised, it will not be caught, and the test case will be
453           deemed to have suffered an error, exactly as for an
454           unexpected exception.
456           If called with callableObj omitted or None, will return a
457           context object used like this::
459                with self.assertRaises(SomeException):
460                    do_something()
462           The context manager keeps a reference to the exception as
463           the 'exception' attribute. This allows you to inspect the
464           exception after the assertion::
466               with self.assertRaises(SomeException) as cm:
467                   do_something()
468               the_exception = cm.exception
469               self.assertEqual(the_exception.error_code, 3)
470        """
471        context = _AssertRaisesContext(excClass, self)
472        if callableObj is None:
473            return context
474        with context:
475            callableObj(*args, **kwargs)
477    def _getAssertEqualityFunc(self, first, second):
478        """Get a detailed comparison function for the types of the two args.
480        Returns: A callable accepting (first, second, msg=None) that will
481        raise a failure exception if first != second with a useful human
482        readable error message for those types.
483        """
484        #
485        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
486        # and vice versa.  I opted for the conservative approach in case
487        # subclasses are not intended to be compared in detail to their super
488        # class instances using a type equality func.  This means testing
489        # subtypes won't automagically use the detailed comparison.  Callers
490        # should use their type specific assertSpamEqual method to compare
491        # subclasses if the detailed comparison is desired and appropriate.
492        # See the discussion in http://bugs.python.org/issue2578.
493        #
494        if type(first) is type(second):
495            asserter = self._type_equality_funcs.get(type(first))
496            if asserter is not None:
497                if isinstance(asserter, basestring):
498                    asserter = getattr(self, asserter)
499                return asserter
501        return self._baseAssertEqual
503    def _baseAssertEqual(self, first, second, msg=None):
504        """The default assertEqual implementation, not type specific."""
505        if not first == second:
506            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
507            msg = self._formatMessage(msg, standardMsg)
508            raise self.failureException(msg)
510    def assertEqual(self, first, second, msg=None):
511        """Fail if the two objects are unequal as determined by the '=='
512           operator.
513        """
514        assertion_func = self._getAssertEqualityFunc(first, second)
515        assertion_func(first, second, msg=msg)
517    def assertNotEqual(self, first, second, msg=None):
518        """Fail if the two objects are equal as determined by the '!='
519           operator.
520        """
521        if not first != second:
522            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
523                                                          safe_repr(second)))
524            raise self.failureException(msg)
527    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
528        """Fail if the two objects are unequal as determined by their
529           difference rounded to the given number of decimal places
530           (default 7) and comparing to zero, or by comparing that the
531           between the two objects is more than the given delta.
533           Note that decimal places (from zero) are usually not the same
534           as significant digits (measured from the most signficant digit).
536           If the two objects compare equal then they will automatically
537           compare almost equal.
538        """
539        if first == second:
540            # shortcut
541            return
542        if delta is not None and places is not None:
543            raise TypeError("specify delta or places not both")
545        if delta is not None:
546            if abs(first - second) <= delta:
547                return
549            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
550                                                        safe_repr(second),
551                                                        safe_repr(delta))
552        else:
553            if places is None:
554                places = 7
556            if round(abs(second-first), places) == 0:
557                return
559            standardMsg = '%s != %s within %r places' % (safe_repr(first),
560                                                          safe_repr(second),
561                                                          places)
562        msg = self._formatMessage(msg, standardMsg)
563        raise self.failureException(msg)
565    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
566        """Fail if the two objects are equal as determined by their
567           difference rounded to the given number of decimal places
568           (default 7) and comparing to zero, or by comparing that the
569           between the two objects is less than the given delta.
571           Note that decimal places (from zero) are usually not the same
572           as significant digits (measured from the most signficant digit).
574           Objects that are equal automatically fail.
575        """
576        if delta is not None and places is not None:
577            raise TypeError("specify delta or places not both")
578        if delta is not None:
579            if not (first == second) and abs(first - second) > delta:
580                return
581            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
582                                                        safe_repr(second),
583                                                        safe_repr(delta))
584        else:
585            if places is None:
586                places = 7
587            if not (first == second) and round(abs(second-first), places) != 0:
588                return
589            standardMsg = '%s == %s within %r places' % (safe_repr(first),
590                                                         safe_repr(second),
591                                                         places)
593        msg = self._formatMessage(msg, standardMsg)
594        raise self.failureException(msg)
596    # Synonyms for assertion methods
598    # The plurals are undocumented.  Keep them that way to discourage use.
599    # Do not add more.  Do not remove.
600    # Going through a deprecation cycle on these would annoy many people.
601    assertEquals = assertEqual
602    assertNotEquals = assertNotEqual
603    assertAlmostEquals = assertAlmostEqual
604    assertNotAlmostEquals = assertNotAlmostEqual
605    assert_ = assertTrue
607    # These fail* assertion method names are pending deprecation and will
608    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
609    def _deprecate(original_func):
610        def deprecated_func(*args, **kwargs):
611            warnings.warn(
612                'Please use {0} instead.'.format(original_func.__name__),
613                PendingDeprecationWarning, 2)
614            return original_func(*args, **kwargs)
615        return deprecated_func
617    failUnlessEqual = _deprecate(assertEqual)
618    failIfEqual = _deprecate(assertNotEqual)
619    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
620    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
621    failUnless = _deprecate(assertTrue)
622    failUnlessRaises = _deprecate(assertRaises)
623    failIf = _deprecate(assertFalse)
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        On failure the standard message includes:
          * a summary line with both reprs abbreviated to 30 characters;
          * the first differing index, or a note that an argument has no
            length or cannot be indexed;
          * any extra trailing elements;
          * an ndiff of the pprint-formatted sequences, truncated according
            to self.maxDiff via _truncateMessage.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message; combined with or substituted for the
                    generated description by _formatMessage.

        Raises:
            self.failureException: if seq_type is given and either argument is
                not an instance of it, or if the sequences differ.
        """
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # `differing` accumulates the human-readable description. It is set
        # right away (and the element-wise comparison skipped) when an
        # argument has no len().
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            # Abbreviated reprs for the one-line summary.
            seq1_repr = safe_repr(seq1)
            seq2_repr = safe_repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Walk the common prefix looking for the first differing element.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 (i, item1, item2))
                    break
            else:
                # Reached only when no break happened: the common prefix matched.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe any extra trailing elements of the longer sequence.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, seq1[len2]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    # "First extra element" here means the first of seq2's extras.
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, seq2[len1]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        # Line-by-line ndiff of the pretty-printed sequences, subject to maxDiff.
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))
        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
728    def _truncateMessage(self, message, diff):
729        max_diff = self.maxDiff
730        if max_diff is None or len(diff) <= max_diff:
731            return message + diff
732        return message + (DIFF_OMITTED % len(diff))
734    def assertListEqual(self, list1, list2, msg=None):
735        """A list-specific equality assertion.
737        Args:
738            list1: The first list to compare.
739            list2: The second list to compare.
740            msg: Optional message to use on failure instead of a list of
741                    differences.
743        """
744        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
746    def assertTupleEqual(self, tuple1, tuple2, msg=None):
747        """A tuple-specific equality assertion.
749        Args:
750            tuple1: The first tuple to compare.
751            tuple2: The second tuple to compare.
752            msg: Optional message to use on failure instead of a list of
753                    differences.
754        """
755        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
757    def assertSetEqual(self, set1, set2, msg=None):
758        """A set-specific equality assertion.
760        Args:
761            set1: The first set to compare.
762            set2: The second set to compare.
763            msg: Optional message to use on failure instead of a list of
764                    differences.
766        assertSetEqual uses ducktyping to support different types of sets, and
767        is optimized for sets specifically (parameters must support a
768        difference method).
769        """
770        try:
771            difference1 = set1.difference(set2)
772        except TypeError, e:
773            self.fail('invalid type when attempting set difference: %s' % e)
774        except AttributeError, e:
775            self.fail('first argument does not support set difference: %s' % e)
777        try:
778            difference2 = set2.difference(set1)
779        except TypeError, e:
780            self.fail('invalid type when attempting set difference: %s' % e)
781        except AttributeError, e:
782            self.fail('second argument does not support set difference: %s' % e)
784        if not (difference1 or difference2):
785            return
787        lines = []
788        if difference1:
789            lines.append('Items in the first set but not the second:')
790            for item in difference1:
791                lines.append(repr(item))
792        if difference2:
793            lines.append('Items in the second set but not the first:')
794            for item in difference2:
795                lines.append(repr(item))
797        standardMsg = '\n'.join(lines)
798        self.fail(self._formatMessage(msg, standardMsg))
800    def assertIn(self, member, container, msg=None):
801        """Just like self.assertTrue(a in b), but with a nicer default message."""
802        if member not in container:
803            standardMsg = '%s not found in %s' % (safe_repr(member),
804                                                  safe_repr(container))
805            self.fail(self._formatMessage(msg, standardMsg))
807    def assertNotIn(self, member, container, msg=None):
808        """Just like self.assertTrue(a not in b), but with a nicer default message."""
809        if member in container:
810            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
811                                                        safe_repr(container))
812            self.fail(self._formatMessage(msg, standardMsg))
814    def assertIs(self, expr1, expr2, msg=None):
815        """Just like self.assertTrue(a is b), but with a nicer default message."""
816        if expr1 is not expr2:
817            standardMsg = '%s is not %s' % (safe_repr(expr1),
818                                             safe_repr(expr2))
819            self.fail(self._formatMessage(msg, standardMsg))
821    def assertIsNot(self, expr1, expr2, msg=None):
822        """Just like self.assertTrue(a is not b), but with a nicer default message."""
823        if expr1 is expr2:
824            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
825            self.fail(self._formatMessage(msg, standardMsg))
827    def assertDictEqual(self, d1, d2, msg=None):
828        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
829        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
831        if d1 != d2:
832            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
833            diff = ('\n' + '\n'.join(difflib.ndiff(
834                           pprint.pformat(d1).splitlines(),
835                           pprint.pformat(d2).splitlines())))
836            standardMsg = self._truncateMessage(standardMsg, diff)
837            self.fail(self._formatMessage(msg, standardMsg))
839    def assertDictContainsSubset(self, expected, actual, msg=None):
840        """Checks whether actual is a superset of expected."""
841        missing = []
842        mismatched = []
843        for key, value in expected.iteritems():
844            if key not in actual:
845                missing.append(key)
846            elif value != actual[key]:
847                mismatched.append('%s, expected: %s, actual: %s' %
848                                  (safe_repr(key), safe_repr(value),
849                                   safe_repr(actual[key])))
851        if not (missing or mismatched):
852            return
854        standardMsg = ''
855        if missing:
856            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
857                                                    missing)
858        if mismatched:
859            if standardMsg:
860                standardMsg += '; '
861            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
863        self.fail(self._formatMessage(msg, standardMsg))
865    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
866        """An unordered sequence specific comparison. It asserts that
867        actual_seq and expected_seq have the same element counts.
868        Equivalent to::
870            self.assertEqual(Counter(iter(actual_seq)),
871                             Counter(iter(expected_seq)))
873        Asserts that each element has the same count in both sequences.
874        Example:
875            - [0, 1, 1] and [1, 0, 1] compare equal.
876            - [0, 0, 1] and [0, 1] compare unequal.
877        """
878        first_seq, second_seq = list(expected_seq), list(actual_seq)
879        with warnings.catch_warnings():
880            if sys.py3kwarning:
881                # Silence Py3k warning raised during the sorting
882                for _msg in ["(code|dict|type) inequality comparisons",
883                             "builtin_function_or_method order comparisons",
884                             "comparing unequal types"]:
885                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
886            try:
887                first = collections.Counter(first_seq)
888                second = collections.Counter(second_seq)
889            except TypeError:
890                # Handle case with unhashable elements
891                differences = _count_diff_all_purpose(first_seq, second_seq)
892            else:
893                if first == second:
894                    return
895                differences = _count_diff_hashable(first_seq, second_seq)
897        if differences:
898            standardMsg = 'Element counts were not equal:\n'
899            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
900            diffMsg = '\n'.join(lines)
901            standardMsg = self._truncateMessage(standardMsg, diffMsg)
902            msg = self._formatMessage(msg, standardMsg)
903            self.fail(msg)
905    def assertMultiLineEqual(self, first, second, msg=None):
906        """Assert that two multi-line strings are equal."""
907        self.assertIsInstance(first, basestring,
908                'First argument is not a string')
909        self.assertIsInstance(second, basestring,
910                'Second argument is not a string')
912        if first != second:
913            # don't use difflib if the strings are too long
914            if (len(first) > self._diffThreshold or
915                len(second) > self._diffThreshold):
916                self._baseAssertEqual(first, second, msg)
917            firstlines = first.splitlines(True)
918            secondlines = second.splitlines(True)
919            if len(firstlines) == 1 and first.strip('\r\n') == first:
920                firstlines = [first + '\n']
921                secondlines = [second + '\n']
922            standardMsg = '%s != %s' % (safe_repr(first, True),
923                                        safe_repr(second, True))
924            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
925            standardMsg = self._truncateMessage(standardMsg, diff)
926            self.fail(self._formatMessage(msg, standardMsg))
928    def assertLess(self, a, b, msg=None):
929        """Just like self.assertTrue(a < b), but with a nicer default message."""
930        if not a < b:
931            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
932            self.fail(self._formatMessage(msg, standardMsg))
934    def assertLessEqual(self, a, b, msg=None):
935        """Just like self.assertTrue(a <= b), but with a nicer default message."""
936        if not a <= b:
937            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
938            self.fail(self._formatMessage(msg, standardMsg))
940    def assertGreater(self, a, b, msg=None):
941        """Just like self.assertTrue(a > b), but with a nicer default message."""
942        if not a > b:
943            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
944            self.fail(self._formatMessage(msg, standardMsg))
946    def assertGreaterEqual(self, a, b, msg=None):
947        """Just like self.assertTrue(a >= b), but with a nicer default message."""
948        if not a >= b:
949            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
950            self.fail(self._formatMessage(msg, standardMsg))
952    def assertIsNone(self, obj, msg=None):
953        """Same as self.assertTrue(obj is None), with a nicer default message."""
954        if obj is not None:
955            standardMsg = '%s is not None' % (safe_repr(obj),)
956            self.fail(self._formatMessage(msg, standardMsg))
958    def assertIsNotNone(self, obj, msg=None):
959        """Included for symmetry with assertIsNone."""
960        if obj is None:
961            standardMsg = 'unexpectedly None'
962            self.fail(self._formatMessage(msg, standardMsg))
964    def assertIsInstance(self, obj, cls, msg=None):
965        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
966        default message."""
967        if not isinstance(obj, cls):
968            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
969            self.fail(self._formatMessage(msg, standardMsg))
971    def assertNotIsInstance(self, obj, cls, msg=None):
972        """Included for symmetry with assertIsInstance."""
973        if isinstance(obj, cls):
974            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
975            self.fail(self._formatMessage(msg, standardMsg))
977    def assertRaisesRegexp(self, expected_exception, expected_regexp,
978                           callable_obj=None, *args, **kwargs):
979        """Asserts that the message in a raised exception matches a regexp.
981        Args:
982            expected_exception: Exception class expected to be raised.
983            expected_regexp: Regexp (re pattern object or string) expected
984                    to be found in error message.
985            callable_obj: Function to be called.
986            args: Extra args.
987            kwargs: Extra kwargs.
988        """
989        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
990        if callable_obj is None:
991            return context
992        with context:
993            callable_obj(*args, **kwargs)
995    def assertRegexpMatches(self, text, expected_regexp, msg=None):
996        """Fail the test unless the text matches the regular expression."""
997        if isinstance(expected_regexp, basestring):
998            expected_regexp = re.compile(expected_regexp)
999        if not expected_regexp.search(text):
1000            msg = msg or "Regexp didn't match"
1001            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1002            raise self.failureException(msg)
1004    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1005        """Fail the test if the text matches the regular expression."""
1006        if isinstance(unexpected_regexp, basestring):
1007            unexpected_regexp = re.compile(unexpected_regexp)
1008        match = unexpected_regexp.search(text)
1009        if match:
1010            msg = msg or "Regexp matched"
1011            msg = '%s: %r matches %r in %r' % (msg,
1012                                               text[match.start():match.end()],
1013                                               unexpected_regexp.pattern,
1014                                               text)
1015            raise self.failureException(msg)
class FunctionTestCase(TestCase):
    """A TestCase that runs a plain test function.

    This lets pre-existing test functions be used inside the unittest
    framework without rewriting them as TestCase methods. Optional set-up
    and tidy-up callables may be supplied; they are invoked with no
    arguments from this class's setUp() and tearDown() hooks. As with
    TestCase, the tidy-up ('tearDown') function will always be called if
    the set-up ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # The wrapped callables plus an optional override for
        # shortDescription().
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        fixture = self._setUpFunc
        if fixture is not None:
            fixture()

    def tearDown(self):
        cleanup = self._tearDownFunc
        if cleanup is not None:
            cleanup()

    def runTest(self):
        self._testFunc()

    def id(self):
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        # Two instances are equal when they wrap the same callables and
        # carry the same description.
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash over the same fields __eq__ compares, plus the concrete type.
        fields = (type(self), self._setUpFunc, self._tearDownFunc,
                  self._testFunc, self._description)
        return hash(fields)

    def __str__(self):
        return "%s (%s)" % (strclass(self.__class__), self._testFunc.__name__)

    def __repr__(self):
        return "<%s tec=%s>" % (strclass(self.__class__), self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the docstring's first line.

        Returns None when no description was given and the wrapped
        function has no docstring or its first line is blank.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None