1"""Test case implementation"""
2
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import types
10import warnings
11
12from . import result
13from .util import (
14    strclass, safe_repr, unorderable_list_difference,
15    _count_diff_all_purpose, _count_diff_hashable
16)
17
18
# NOTE(review): this flag is not read anywhere in this file; presumably other
# parts of the package look for it to recognise framework modules -- confirm.
__unittest = True


# Appended by TestCase._truncateMessage in place of a diff that is longer than
# maxDiff; the %s placeholder receives the length of the omitted diff.
DIFF_OMITTED = ('\nDiff is %s characters long. '
                 'Set self.maxDiff to None to see it.')
24
class SkipTest(Exception):
    """Signal that the running test should be reported as skipped.

    Raising this from setUp() or from a test method makes TestCase.run()
    record a skip, using str() of the exception as the reason.  Most tests
    should prefer TestCase.skipTest() or one of the skipping decorators over
    raising this class directly.
    """
33
class _ExpectedFailure(Exception):
    """Internal signal: a test marked as an expected failure did fail.

    Instances carry the sys.exc_info() triple of the original failure in
    the ``exc_info`` attribute so that it can be handed on to the result
    object.  This is an implementation detail, not public API.
    """

    def __init__(self, exc_info):
        # Keep the original failure; the exception itself has no message.
        self.exc_info = exc_info
        super(_ExpectedFailure, self).__init__()
44
class _UnexpectedSuccess(Exception):
    """Internal signal: a test that was expected to fail passed instead."""
50
def _id(obj):
    """Return *obj* unchanged.

    Used as the no-op decorator returned by skipIf()/skipUnless() when the
    test should not be skipped.
    """
    return obj
53
def skip(reason):
    """Return a decorator that unconditionally marks a test as skipped.

    A decorated function is replaced by a wrapper that raises
    SkipTest(reason) when called; a decorated class (new-style or classic)
    is left as it is.  In both cases the returned object gets
    ``__unittest_skip__ = True`` and ``__unittest_skip_why__ = reason``,
    which TestCase.run() checks before running anything.
    """
    def decorator(test_item):
        is_class = isinstance(test_item, (type, types.ClassType))
        if not is_class:
            original = test_item

            @functools.wraps(original)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)

            test_item = skip_wrapper

        # Flag the (possibly wrapped) item so the runner can short-circuit.
        test_item.__unittest_skip_why__ = reason
        test_item.__unittest_skip__ = True
        return test_item
    return decorator
69
def skipIf(condition, reason):
    """Return a skipping decorator when *condition* is true.

    When *condition* is false the identity decorator is returned, so the
    decorated test is left untouched.
    """
    return skip(reason) if condition else _id
77
def skipUnless(condition, reason):
    """Return a skipping decorator unless *condition* is true.

    When *condition* is true the identity decorator is returned, so the
    decorated test is left untouched.
    """
    return _id if condition else skip(reason)
85
86
def expectedFailure(func):
    """Decorate *func* as a test that is expected to fail.

    The wrapper calls *func*.  If it raises an Exception, the wrapper raises
    _ExpectedFailure carrying the sys.exc_info() of that exception; if it
    returns normally, the wrapper raises _UnexpectedSuccess.  Exceptions
    that do not derive from Exception propagate unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
96
97
class _AssertRaisesContext(object):
    """Context manager backing the TestCase.assertRaises* methods.

    On exit it checks that an exception of the expected class was raised
    inside the ``with`` block and, if a regular expression was supplied,
    that the regexp matches str() of the exception.  The caught exception
    is kept on the ``exception`` attribute.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        self.expected = expected
        self.expected_regexp = expected_regexp
        # Failures are raised with the owning test case's exception class.
        self.failureException = test_case.failureException

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            # Nothing was raised at all: report the expected class by name,
            # falling back to str() for objects that have no __name__.
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            raise self.failureException(
                "{0} not raised".format(exc_name))

        if not issubclass(exc_type, self.expected):
            # Unexpected exception type: returning False re-raises it.
            return False

        # Store the exception for later retrieval by the caller.
        self.exception = exc_value

        pattern = self.expected_regexp
        if pattern is None:
            return True

        if isinstance(pattern, basestring):
            pattern = re.compile(pattern)
        text = str(exc_value)
        if not pattern.search(text):
            raise self.failureException('"%s" does not match "%s"' %
                     (pattern.pattern, text))
        return True
131
132
133class TestCase(object):
134    """A class whose instances are single test cases.
135
136    By default, the test code itself should be placed in a method named
137    'runTest'.
138
139    If the fixture may be used for many test cases, create as
140    many test methods as are needed. When instantiating such a TestCase
141    subclass, specify in the constructor arguments the name of the test method
142    that the instance is to execute.
143
144    Test authors should subclass TestCase for their own tests. Construction
145    and deconstruction of the test's environment ('fixture') can be
146    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
147
148    If it is necessary to override the __init__ method, the base class
149    __init__ method must always be called. It is important that subclasses
150    should not change the signature of their __init__ method, since instances
151    of the classes are instantiated automatically by parts of the framework
152    in order to be run.
153
154    When subclassing TestCase, you can set these attributes:
155    * failureException: determines which exception will be raised when
156        the instance's assertion methods fail; test methods raising this
157        exception will be deemed to have 'failed' rather than 'errored'.
158    * longMessage: determines whether long messages (including repr of
159        objects used in assert methods) will be printed on failure in *addition*
160        to any explicit message passed.
161    * maxDiff: sets the maximum length of a diff in failure messages
162        by assert methods using difflib. It is looked up as an instance
163        attribute so can be configured by individual tests if required.
164    """
165
    # Exception class raised by the assert* methods.  run() reports a test
    # method that raises it through addFailure() instead of addError().
    failureException = AssertionError

    # Read by _formatMessage: when false, an explicit msg replaces the
    # standard message; when true, the two are joined with ' : '.
    longMessage = False

    # Longest diff (in characters) that _truncateMessage will include in a
    # failure message; None disables the limit.
    maxDiff = 80*8

    # If a string is longer than _diffThreshold, use normal comparison instead
    # of difflib.  See #11763.
    _diffThreshold = 2**16

    # Attribute used by TestSuite for classSetUp

    _classSetupFailed = False
179
180    def __init__(self, methodName='runTest'):
181        """Create an instance of the class that will use the named test
182           method when executed. Raises a ValueError if the instance does
183           not have a method with the specified name.
184        """
185        self._testMethodName = methodName
186        self._resultForDoCleanups = None
187        try:
188            testMethod = getattr(self, methodName)
189        except AttributeError:
190            raise ValueError("no such test method in %s: %s" %
191                  (self.__class__, methodName))
192        self._testMethodDoc = testMethod.__doc__
193        self._cleanups = []
194
195        # Map types to custom assertEqual functions that will compare
196        # instances of said type in more detail to generate a more useful
197        # error message.
198        self._type_equality_funcs = {}
199        self.addTypeEqualityFunc(dict, 'assertDictEqual')
200        self.addTypeEqualityFunc(list, 'assertListEqual')
201        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
202        self.addTypeEqualityFunc(set, 'assertSetEqual')
203        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
204        try:
205            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
206        except NameError:
207            # No unicode support in this build
208            pass
209
    def addTypeEqualityFunc(self, typeobj, function):
        """Add a type specific assertEqual style function to compare a type.

        This method is for use by TestCase subclasses that need to register
        their own type equality functions to provide nicer error messages.
        A later registration for the same type replaces the earlier one.

        Args:
            typeobj: The data type to call this function on when both values
                    are of the same type in assertEqual().
            function: The callable taking two arguments and an optional
                    msg= argument that raises self.failureException with a
                    useful error message when the two arguments are not equal.
                    A string is also accepted: it is treated as the name of
                    a method on this instance and resolved when the
                    comparison function is looked up.
        """
        self._type_equality_funcs[typeobj] = function
224
    def addCleanup(self, function, *args, **kwargs):
        """Add a function, with arguments, to be called when the test is
        completed. Functions added are called on a LIFO basis and are
        called after tearDown on test failure or success.

        Cleanup items are called even if setUp fails (unlike tearDown).

        The call is recorded as a (function, args, kwargs) tuple and is
        executed later by doCleanups() or debug()."""
        self._cleanups.append((function, args, kwargs))
232
    def setUp(self):
        """Hook method for setting up the test fixture before exercising it.

        Called by run() and debug() before the test method.  The default
        implementation does nothing.
        """
        pass
236
    def tearDown(self):
        """Hook method for deconstructing the test fixture after testing it.

        Called by run() after the test method whenever setUp() completed
        without raising.  The default implementation does nothing.
        """
        pass
240
    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in the class.

        The default implementation does nothing.
        """
        # NOTE(review): not invoked anywhere in this file; presumably the
        # suite machinery calls it -- confirm.
244
    @classmethod
    def tearDownClass(cls):
        """Hook method for deconstructing the class fixture after running all tests in the class.

        The default implementation does nothing.
        """
        # NOTE(review): not invoked anywhere in this file; presumably the
        # suite machinery calls it -- confirm.
248
    def countTestCases(self):
        """Return the number of tests this object represents, which is always 1."""
        return 1
251
    def defaultTestResult(self):
        """Return a new result.TestResult instance.

        run() calls this to obtain a result object when none is passed in.
        """
        return result.TestResult()
254
255    def shortDescription(self):
256        """Returns a one-line description of the test, or None if no
257        description has been provided.
258
259        The default implementation of this method returns the first line of
260        the specified test method's docstring.
261        """
262        doc = self._testMethodDoc
263        return doc and doc.split("\n")[0].strip() or None
264
265
    def id(self):
        """Return a string identifying this test.

        The string is strclass() of the instance's class, a dot, and the
        name of the test method.
        """
        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
268
269    def __eq__(self, other):
270        if type(self) is not type(other):
271            return NotImplemented
272
273        return self._testMethodName == other._testMethodName
274
    def __ne__(self, other):
        """Return the negation of ``self == other``."""
        return not self == other
277
    def __hash__(self):
        """Hash on (type, test method name), the same data __eq__ compares."""
        return hash((type(self), self._testMethodName))
280
    def __str__(self):
        """Return '<method name> (<strclass of the class>)'."""
        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
283
284    def __repr__(self):
285        return "<%s testMethod=%s>" % \
286               (strclass(self.__class__), self._testMethodName)
287
288    def _addSkip(self, result, reason):
289        addSkip = getattr(result, 'addSkip', None)
290        if addSkip is not None:
291            addSkip(self, reason)
292        else:
293            warnings.warn("TestResult has no addSkip method, skips not reported",
294                          RuntimeWarning, 2)
295            result.addSuccess(self)
296
297    def run(self, result=None):
298        orig_result = result
299        if result is None:
300            result = self.defaultTestResult()
301            startTestRun = getattr(result, 'startTestRun', None)
302            if startTestRun is not None:
303                startTestRun()
304
305        self._resultForDoCleanups = result
306        result.startTest(self)
307
308        testMethod = getattr(self, self._testMethodName)
309        if (getattr(self.__class__, "__unittest_skip__", False) or
310            getattr(testMethod, "__unittest_skip__", False)):
311            # If the class or method was skipped.
312            try:
313                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
314                            or getattr(testMethod, '__unittest_skip_why__', ''))
315                self._addSkip(result, skip_why)
316            finally:
317                result.stopTest(self)
318            return
319        try:
320            success = False
321            try:
322                self.setUp()
323            except SkipTest as e:
324                self._addSkip(result, str(e))
325            except KeyboardInterrupt:
326                raise
327            except:
328                result.addError(self, sys.exc_info())
329            else:
330                try:
331                    testMethod()
332                except KeyboardInterrupt:
333                    raise
334                except self.failureException:
335                    result.addFailure(self, sys.exc_info())
336                except _ExpectedFailure as e:
337                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
338                    if addExpectedFailure is not None:
339                        addExpectedFailure(self, e.exc_info)
340                    else:
341                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
342                                      RuntimeWarning)
343                        result.addSuccess(self)
344                except _UnexpectedSuccess:
345                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
346                    if addUnexpectedSuccess is not None:
347                        addUnexpectedSuccess(self)
348                    else:
349                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
350                                      RuntimeWarning)
351                        result.addFailure(self, sys.exc_info())
352                except SkipTest as e:
353                    self._addSkip(result, str(e))
354                except:
355                    result.addError(self, sys.exc_info())
356                else:
357                    success = True
358
359                try:
360                    self.tearDown()
361                except KeyboardInterrupt:
362                    raise
363                except:
364                    result.addError(self, sys.exc_info())
365                    success = False
366
367            cleanUpSuccess = self.doCleanups()
368            success = success and cleanUpSuccess
369            if success:
370                result.addSuccess(self)
371        finally:
372            result.stopTest(self)
373            if orig_result is None:
374                stopTestRun = getattr(result, 'stopTestRun', None)
375                if stopTestRun is not None:
376                    stopTestRun()
377
378    def doCleanups(self):
379        """Execute all cleanup functions. Normally called for you after
380        tearDown."""
381        result = self._resultForDoCleanups
382        ok = True
383        while self._cleanups:
384            function, args, kwargs = self._cleanups.pop(-1)
385            try:
386                function(*args, **kwargs)
387            except KeyboardInterrupt:
388                raise
389            except:
390                ok = False
391                result.addError(self, sys.exc_info())
392        return ok
393
    def __call__(self, *args, **kwds):
        """Calling the instance is the same as calling run() with the same arguments."""
        return self.run(*args, **kwds)
396
397    def debug(self):
398        """Run the test without collecting errors in a TestResult"""
399        self.setUp()
400        getattr(self, self._testMethodName)()
401        self.tearDown()
402        while self._cleanups:
403            function, args, kwargs = self._cleanups.pop(-1)
404            function(*args, **kwargs)
405
    def skipTest(self, reason):
        """Skip this test by raising SkipTest(reason)."""
        raise SkipTest(reason)
409
    def fail(self, msg=None):
        """Fail immediately by raising self.failureException(msg)."""
        raise self.failureException(msg)
413
414    def assertFalse(self, expr, msg=None):
415        """Check that the expression is false."""
416        if expr:
417            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
418            raise self.failureException(msg)
419
420    def assertTrue(self, expr, msg=None):
421        """Check that the expression is true."""
422        if not expr:
423            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
424            raise self.failureException(msg)
425
426    def _formatMessage(self, msg, standardMsg):
427        """Honour the longMessage attribute when generating failure messages.
428        If longMessage is False this means:
429        * Use only an explicit message if it is provided
430        * Otherwise use the standard message for the assert
431
432        If longMessage is True:
433        * Use the standard message
434        * If an explicit message is provided, plus ' : ' and the explicit message
435        """
436        if not self.longMessage:
437            return msg or standardMsg
438        if msg is None:
439            return standardMsg
440        try:
441            # don't switch to '{}' formatting in Python 2.X
442            # it changes the way unicode input is handled
443            return '%s : %s' % (standardMsg, msg)
444        except UnicodeDecodeError:
445            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
446
447
448    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
449        """Fail unless an exception of class excClass is raised
450           by callableObj when invoked with arguments args and keyword
451           arguments kwargs. If a different type of exception is
452           raised, it will not be caught, and the test case will be
453           deemed to have suffered an error, exactly as for an
454           unexpected exception.
455
456           If called with callableObj omitted or None, will return a
457           context object used like this::
458
459                with self.assertRaises(SomeException):
460                    do_something()
461
462           The context manager keeps a reference to the exception as
463           the 'exception' attribute. This allows you to inspect the
464           exception after the assertion::
465
466               with self.assertRaises(SomeException) as cm:
467                   do_something()
468               the_exception = cm.exception
469               self.assertEqual(the_exception.error_code, 3)
470        """
471        context = _AssertRaisesContext(excClass, self)
472        if callableObj is None:
473            return context
474        with context:
475            callableObj(*args, **kwargs)
476
477    def _getAssertEqualityFunc(self, first, second):
478        """Get a detailed comparison function for the types of the two args.
479
480        Returns: A callable accepting (first, second, msg=None) that will
481        raise a failure exception if first != second with a useful human
482        readable error message for those types.
483        """
484        #
485        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
486        # and vice versa.  I opted for the conservative approach in case
487        # subclasses are not intended to be compared in detail to their super
488        # class instances using a type equality func.  This means testing
489        # subtypes won't automagically use the detailed comparison.  Callers
490        # should use their type specific assertSpamEqual method to compare
491        # subclasses if the detailed comparison is desired and appropriate.
492        # See the discussion in http://bugs.python.org/issue2578.
493        #
494        if type(first) is type(second):
495            asserter = self._type_equality_funcs.get(type(first))
496            if asserter is not None:
497                if isinstance(asserter, basestring):
498                    asserter = getattr(self, asserter)
499                return asserter
500
501        return self._baseAssertEqual
502
503    def _baseAssertEqual(self, first, second, msg=None):
504        """The default assertEqual implementation, not type specific."""
505        if not first == second:
506            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
507            msg = self._formatMessage(msg, standardMsg)
508            raise self.failureException(msg)
509
510    def assertEqual(self, first, second, msg=None):
511        """Fail if the two objects are unequal as determined by the '=='
512           operator.
513        """
514        assertion_func = self._getAssertEqualityFunc(first, second)
515        assertion_func(first, second, msg=msg)
516
517    def assertNotEqual(self, first, second, msg=None):
518        """Fail if the two objects are equal as determined by the '!='
519           operator.
520        """
521        if not first != second:
522            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
523                                                          safe_repr(second)))
524            raise self.failureException(msg)
525
526
527    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
528        """Fail if the two objects are unequal as determined by their
529           difference rounded to the given number of decimal places
530           (default 7) and comparing to zero, or by comparing that the
531           between the two objects is more than the given delta.
532
533           Note that decimal places (from zero) are usually not the same
534           as significant digits (measured from the most signficant digit).
535
536           If the two objects compare equal then they will automatically
537           compare almost equal.
538        """
539        if first == second:
540            # shortcut
541            return
542        if delta is not None and places is not None:
543            raise TypeError("specify delta or places not both")
544
545        if delta is not None:
546            if abs(first - second) <= delta:
547                return
548
549            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
550                                                        safe_repr(second),
551                                                        safe_repr(delta))
552        else:
553            if places is None:
554                places = 7
555
556            if round(abs(second-first), places) == 0:
557                return
558
559            standardMsg = '%s != %s within %r places' % (safe_repr(first),
560                                                          safe_repr(second),
561                                                          places)
562        msg = self._formatMessage(msg, standardMsg)
563        raise self.failureException(msg)
564
565    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
566        """Fail if the two objects are equal as determined by their
567           difference rounded to the given number of decimal places
568           (default 7) and comparing to zero, or by comparing that the
569           between the two objects is less than the given delta.
570
571           Note that decimal places (from zero) are usually not the same
572           as significant digits (measured from the most signficant digit).
573
574           Objects that are equal automatically fail.
575        """
576        if delta is not None and places is not None:
577            raise TypeError("specify delta or places not both")
578        if delta is not None:
579            if not (first == second) and abs(first - second) > delta:
580                return
581            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
582                                                        safe_repr(second),
583                                                        safe_repr(delta))
584        else:
585            if places is None:
586                places = 7
587            if not (first == second) and round(abs(second-first), places) != 0:
588                return
589            standardMsg = '%s == %s within %r places' % (safe_repr(first),
590                                                         safe_repr(second),
591                                                         places)
592
593        msg = self._formatMessage(msg, standardMsg)
594        raise self.failureException(msg)
595
    # Synonyms for assertion methods

    # The plurals are undocumented.  Keep them that way to discourage use.
    # Do not add more.  Do not remove.
    # Going through a deprecation cycle on these would annoy many people.
    # Each name is bound directly to the canonical function, so calling a
    # synonym behaves exactly like calling the original.
    assertEquals = assertEqual
    assertNotEquals = assertNotEqual
    assertAlmostEquals = assertAlmostEqual
    assertNotAlmostEquals = assertNotAlmostEqual
    assert_ = assertTrue
606
607    # These fail* assertion method names are pending deprecation and will
608    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
609    def _deprecate(original_func):
610        def deprecated_func(*args, **kwargs):
611            warnings.warn(
612                'Please use {0} instead.'.format(original_func.__name__),
613                PendingDeprecationWarning, 2)
614            return original_func(*args, **kwargs)
615        return deprecated_func
616
    # Legacy fail* spellings.  Each one wraps the canonical assert* method
    # with _deprecate, so calling it emits a PendingDeprecationWarning and
    # then delegates to the original.
    failUnlessEqual = _deprecate(assertEqual)
    failIfEqual = _deprecate(assertNotEqual)
    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
    failUnless = _deprecate(assertTrue)
    failUnlessRaises = _deprecate(assertRaises)
    failIf = _deprecate(assertFalse)
624
625    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
626        """An equality assertion for ordered sequences (like lists and tuples).
627
628        For the purposes of this function, a valid ordered sequence type is one
629        which can be indexed, has a length, and has an equality operator.
630
631        Args:
632            seq1: The first sequence to compare.
633            seq2: The second sequence to compare.
634            seq_type: The expected datatype of the sequences, or None if no
635                    datatype should be enforced.
636            msg: Optional message to use on failure instead of a list of
637                    differences.
638        """
639        if seq_type is not None:
640            seq_type_name = seq_type.__name__
641            if not isinstance(seq1, seq_type):
642                raise self.failureException('First sequence is not a %s: %s'
643                                        % (seq_type_name, safe_repr(seq1)))
644            if not isinstance(seq2, seq_type):
645                raise self.failureException('Second sequence is not a %s: %s'
646                                        % (seq_type_name, safe_repr(seq2)))
647        else:
648            seq_type_name = "sequence"
649
650        differing = None
651        try:
652            len1 = len(seq1)
653        except (TypeError, NotImplementedError):
654            differing = 'First %s has no length.    Non-sequence?' % (
655                    seq_type_name)
656
657        if differing is None:
658            try:
659                len2 = len(seq2)
660            except (TypeError, NotImplementedError):
661                differing = 'Second %s has no length.    Non-sequence?' % (
662                        seq_type_name)
663
664        if differing is None:
665            if seq1 == seq2:
666                return
667
668            seq1_repr = safe_repr(seq1)
669            seq2_repr = safe_repr(seq2)
670            if len(seq1_repr) > 30:
671                seq1_repr = seq1_repr[:30] + '...'
672            if len(seq2_repr) > 30:
673                seq2_repr = seq2_repr[:30] + '...'
674            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
675            differing = '%ss differ: %s != %s\n' % elements
676
677            for i in xrange(min(len1, len2)):
678                try:
679                    item1 = seq1[i]
680                except (TypeError, IndexError, NotImplementedError):
681                    differing += ('\nUnable to index element %d of first %s\n' %
682                                 (i, seq_type_name))
683                    break
684
685                try:
686                    item2 = seq2[i]
687                except (TypeError, IndexError, NotImplementedError):
688                    differing += ('\nUnable to index element %d of second %s\n' %
689                                 (i, seq_type_name))
690                    break
691
692                if item1 != item2:
693                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
694                                 (i, item1, item2))
695                    break
696            else:
697                if (len1 == len2 and seq_type is None and
698                    type(seq1) != type(seq2)):
699                    # The sequences are the same, but have differing types.
700                    return
701
702            if len1 > len2:
703                differing += ('\nFirst %s contains %d additional '
704                             'elements.\n' % (seq_type_name, len1 - len2))
705                try:
706                    differing += ('First extra element %d:\n%s\n' %
707                                  (len2, seq1[len2]))
708                except (TypeError, IndexError, NotImplementedError):
709                    differing += ('Unable to index element %d '
710                                  'of first %s\n' % (len2, seq_type_name))
711            elif len1 < len2:
712                differing += ('\nSecond %s contains %d additional '
713                             'elements.\n' % (seq_type_name, len2 - len1))
714                try:
715                    differing += ('First extra element %d:\n%s\n' %
716                                  (len1, seq2[len1]))
717                except (TypeError, IndexError, NotImplementedError):
718                    differing += ('Unable to index element %d '
719                                  'of second %s\n' % (len1, seq_type_name))
720        standardMsg = differing
721        diffMsg = '\n' + '\n'.join(
722            difflib.ndiff(pprint.pformat(seq1).splitlines(),
723                          pprint.pformat(seq2).splitlines()))
724        standardMsg = self._truncateMessage(standardMsg, diffMsg)
725        msg = self._formatMessage(msg, standardMsg)
726        self.fail(msg)
727
728    def _truncateMessage(self, message, diff):
729        max_diff = self.maxDiff
730        if max_diff is None or len(diff) <= max_diff:
731            return message + diff
732        return message + (DIFF_OMITTED % len(diff))
733
    def assertListEqual(self, list1, list2, msg=None):
        """A list-specific equality assertion.

        Delegates to assertSequenceEqual with seq_type=list, so both
        arguments must be list instances.

        Args:
            list1: The first list to compare.
            list2: The second list to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.

        """
        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
745
    def assertTupleEqual(self, tuple1, tuple2, msg=None):
        """A tuple-specific equality assertion.

        Delegates to assertSequenceEqual with seq_type=tuple, so both
        arguments must be tuple instances.

        Args:
            tuple1: The first tuple to compare.
            tuple2: The second tuple to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.
        """
        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
756
757    def assertSetEqual(self, set1, set2, msg=None):
758        """A set-specific equality assertion.
759
760        Args:
761            set1: The first set to compare.
762            set2: The second set to compare.
763            msg: Optional message to use on failure instead of a list of
764                    differences.
765
766        assertSetEqual uses ducktyping to support different types of sets, and
767        is optimized for sets specifically (parameters must support a
768        difference method).
769        """
770        try:
771            difference1 = set1.difference(set2)
772        except TypeError, e:
773            self.fail('invalid type when attempting set difference: %s' % e)
774        except AttributeError, e:
775            self.fail('first argument does not support set difference: %s' % e)
776
777        try:
778            difference2 = set2.difference(set1)
779        except TypeError, e:
780            self.fail('invalid type when attempting set difference: %s' % e)
781        except AttributeError, e:
782            self.fail('second argument does not support set difference: %s' % e)
783
784        if not (difference1 or difference2):
785            return
786
787        lines = []
788        if difference1:
789            lines.append('Items in the first set but not the second:')
790            for item in difference1:
791                lines.append(repr(item))
792        if difference2:
793            lines.append('Items in the second set but not the first:')
794            for item in difference2:
795                lines.append(repr(item))
796
797        standardMsg = '\n'.join(lines)
798        self.fail(self._formatMessage(msg, standardMsg))
799
800    def assertIn(self, member, container, msg=None):
801        """Just like self.assertTrue(a in b), but with a nicer default message."""
802        if member not in container:
803            standardMsg = '%s not found in %s' % (safe_repr(member),
804                                                  safe_repr(container))
805            self.fail(self._formatMessage(msg, standardMsg))
806
807    def assertNotIn(self, member, container, msg=None):
808        """Just like self.assertTrue(a not in b), but with a nicer default message."""
809        if member in container:
810            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
811                                                        safe_repr(container))
812            self.fail(self._formatMessage(msg, standardMsg))
813
814    def assertIs(self, expr1, expr2, msg=None):
815        """Just like self.assertTrue(a is b), but with a nicer default message."""
816        if expr1 is not expr2:
817            standardMsg = '%s is not %s' % (safe_repr(expr1),
818                                             safe_repr(expr2))
819            self.fail(self._formatMessage(msg, standardMsg))
820
821    def assertIsNot(self, expr1, expr2, msg=None):
822        """Just like self.assertTrue(a is not b), but with a nicer default message."""
823        if expr1 is expr2:
824            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
825            self.fail(self._formatMessage(msg, standardMsg))
826
827    def assertDictEqual(self, d1, d2, msg=None):
828        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
829        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
830
831        if d1 != d2:
832            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
833            diff = ('\n' + '\n'.join(difflib.ndiff(
834                           pprint.pformat(d1).splitlines(),
835                           pprint.pformat(d2).splitlines())))
836            standardMsg = self._truncateMessage(standardMsg, diff)
837            self.fail(self._formatMessage(msg, standardMsg))
838
839    def assertDictContainsSubset(self, expected, actual, msg=None):
840        """Checks whether actual is a superset of expected."""
841        missing = []
842        mismatched = []
843        for key, value in expected.iteritems():
844            if key not in actual:
845                missing.append(key)
846            elif value != actual[key]:
847                mismatched.append('%s, expected: %s, actual: %s' %
848                                  (safe_repr(key), safe_repr(value),
849                                   safe_repr(actual[key])))
850
851        if not (missing or mismatched):
852            return
853
854        standardMsg = ''
855        if missing:
856            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
857                                                    missing)
858        if mismatched:
859            if standardMsg:
860                standardMsg += '; '
861            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
862
863        self.fail(self._formatMessage(msg, standardMsg))
864
865    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
866        """An unordered sequence specific comparison. It asserts that
867        actual_seq and expected_seq have the same element counts.
868        Equivalent to::
869
870            self.assertEqual(Counter(iter(actual_seq)),
871                             Counter(iter(expected_seq)))
872
873        Asserts that each element has the same count in both sequences.
874        Example:
875            - [0, 1, 1] and [1, 0, 1] compare equal.
876            - [0, 0, 1] and [0, 1] compare unequal.
877        """
878        first_seq, second_seq = list(expected_seq), list(actual_seq)
879        with warnings.catch_warnings():
880            if sys.py3kwarning:
881                # Silence Py3k warning raised during the sorting
882                for _msg in ["(code|dict|type) inequality comparisons",
883                             "builtin_function_or_method order comparisons",
884                             "comparing unequal types"]:
885                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
886            try:
887                first = collections.Counter(first_seq)
888                second = collections.Counter(second_seq)
889            except TypeError:
890                # Handle case with unhashable elements
891                differences = _count_diff_all_purpose(first_seq, second_seq)
892            else:
893                if first == second:
894                    return
895                differences = _count_diff_hashable(first_seq, second_seq)
896
897        if differences:
898            standardMsg = 'Element counts were not equal:\n'
899            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
900            diffMsg = '\n'.join(lines)
901            standardMsg = self._truncateMessage(standardMsg, diffMsg)
902            msg = self._formatMessage(msg, standardMsg)
903            self.fail(msg)
904
905    def assertMultiLineEqual(self, first, second, msg=None):
906        """Assert that two multi-line strings are equal."""
907        self.assertIsInstance(first, basestring,
908                'First argument is not a string')
909        self.assertIsInstance(second, basestring,
910                'Second argument is not a string')
911
912        if first != second:
913            # don't use difflib if the strings are too long
914            if (len(first) > self._diffThreshold or
915                len(second) > self._diffThreshold):
916                self._baseAssertEqual(first, second, msg)
917            firstlines = first.splitlines(True)
918            secondlines = second.splitlines(True)
919            if len(firstlines) == 1 and first.strip('\r\n') == first:
920                firstlines = [first + '\n']
921                secondlines = [second + '\n']
922            standardMsg = '%s != %s' % (safe_repr(first, True),
923                                        safe_repr(second, True))
924            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
925            standardMsg = self._truncateMessage(standardMsg, diff)
926            self.fail(self._formatMessage(msg, standardMsg))
927
928    def assertLess(self, a, b, msg=None):
929        """Just like self.assertTrue(a < b), but with a nicer default message."""
930        if not a < b:
931            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
932            self.fail(self._formatMessage(msg, standardMsg))
933
934    def assertLessEqual(self, a, b, msg=None):
935        """Just like self.assertTrue(a <= b), but with a nicer default message."""
936        if not a <= b:
937            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
938            self.fail(self._formatMessage(msg, standardMsg))
939
940    def assertGreater(self, a, b, msg=None):
941        """Just like self.assertTrue(a > b), but with a nicer default message."""
942        if not a > b:
943            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
944            self.fail(self._formatMessage(msg, standardMsg))
945
946    def assertGreaterEqual(self, a, b, msg=None):
947        """Just like self.assertTrue(a >= b), but with a nicer default message."""
948        if not a >= b:
949            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
950            self.fail(self._formatMessage(msg, standardMsg))
951
952    def assertIsNone(self, obj, msg=None):
953        """Same as self.assertTrue(obj is None), with a nicer default message."""
954        if obj is not None:
955            standardMsg = '%s is not None' % (safe_repr(obj),)
956            self.fail(self._formatMessage(msg, standardMsg))
957
958    def assertIsNotNone(self, obj, msg=None):
959        """Included for symmetry with assertIsNone."""
960        if obj is None:
961            standardMsg = 'unexpectedly None'
962            self.fail(self._formatMessage(msg, standardMsg))
963
964    def assertIsInstance(self, obj, cls, msg=None):
965        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
966        default message."""
967        if not isinstance(obj, cls):
968            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
969            self.fail(self._formatMessage(msg, standardMsg))
970
971    def assertNotIsInstance(self, obj, cls, msg=None):
972        """Included for symmetry with assertIsInstance."""
973        if isinstance(obj, cls):
974            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
975            self.fail(self._formatMessage(msg, standardMsg))
976
977    def assertRaisesRegexp(self, expected_exception, expected_regexp,
978                           callable_obj=None, *args, **kwargs):
979        """Asserts that the message in a raised exception matches a regexp.
980
981        Args:
982            expected_exception: Exception class expected to be raised.
983            expected_regexp: Regexp (re pattern object or string) expected
984                    to be found in error message.
985            callable_obj: Function to be called.
986            args: Extra args.
987            kwargs: Extra kwargs.
988        """
989        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
990        if callable_obj is None:
991            return context
992        with context:
993            callable_obj(*args, **kwargs)
994
995    def assertRegexpMatches(self, text, expected_regexp, msg=None):
996        """Fail the test unless the text matches the regular expression."""
997        if isinstance(expected_regexp, basestring):
998            expected_regexp = re.compile(expected_regexp)
999        if not expected_regexp.search(text):
1000            msg = msg or "Regexp didn't match"
1001            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1002            raise self.failureException(msg)
1003
1004    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1005        """Fail the test if the text matches the regular expression."""
1006        if isinstance(unexpected_regexp, basestring):
1007            unexpected_regexp = re.compile(unexpected_regexp)
1008        match = unexpected_regexp.search(text)
1009        if match:
1010            msg = msg or "Regexp matched"
1011            msg = '%s: %r matches %r in %r' % (msg,
1012                                               text[match.start():match.end()],
1013                                               unexpected_regexp.pattern,
1014                                               text)
1015            raise self.failureException(msg)
1016
1017
class FunctionTestCase(TestCase):
    """Adapter that runs a plain function as a test case.

    Lets pre-existing test functions take part in the unittest framework
    without being rewritten as TestCase methods.  Optional set-up and
    tidy-up callables may be given; as with TestCase, the tidy-up
    ('tearDown') function will always be called if the set-up ('setUp')
    function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        """Store the wrapped function, its optional hooks and description."""
        super(FunctionTestCase, self).__init__()
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        """Call the set-up callable, if one was supplied."""
        hook = self._setUpFunc
        if hook is not None:
            hook()

    def tearDown(self):
        """Call the tidy-up callable, if one was supplied."""
        hook = self._tearDownFunc
        if hook is not None:
            hook()

    def runTest(self):
        """Invoke the wrapped test function with no arguments."""
        self._testFunc()

    def id(self):
        """Return the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        """Compare by wrapped function, hooks and description."""
        if not isinstance(other, self.__class__):
            return NotImplemented

        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        """Negation of ``self == other``."""
        return not self == other

    def __hash__(self):
        """Hash over the type plus the same fields __eq__ compares."""
        identity = (type(self), self._setUpFunc, self._tearDownFunc,
                    self._testFunc, self._description)
        return hash(identity)

    def __str__(self):
        """Return "<class name> (<function name>)"."""
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        """Return a representation naming the class and wrapped function."""
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description or the docstring's first line.

        Returns None when no description was given and the wrapped function
        has no docstring or its first docstring line is blank.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None
1077