1"""Test case implementation"""
2
3import sys
4import difflib
5import pprint
6import re
7import unittest
8import warnings
9
10from unittest2 import result
11from unittest2.util import (
12    safe_repr, safe_str, strclass,
13    unorderable_list_difference
14)
15
16from unittest2.compatibility import wraps
17
# Module-level marker; presumably looked up by the unittest result machinery
# to hide this module's frames from failure tracebacks — TODO confirm.
__unittest = True


# Note appended to a failure message in place of a diff longer than
# self.maxDiff (see TestCase._truncateMessage); %s receives the diff length.
DIFF_OMITTED = ('\nDiff is %s characters long. '
                 'Set self.maxDiff to None to see it.')
23
class SkipTest(Exception):
    """
    Raise this exception in a test to skip it.

    Usually you can use TestCase.skipTest() or one of the skipping
    decorators (skip, skipIf, skipUnless) instead of raising this directly.
    TestCase.run reports it as a skip using str() of the exception as the
    reason.
    """
31
32class _ExpectedFailure(Exception):
33    """
34    Raise this when a test is expected to fail.
35
36    This is an implementation detail.
37    """
38
39    def __init__(self, exc_info, bugnumber=None):
40        # can't use super because Python 2.4 exceptions are old style
41        Exception.__init__(self)
42        self.exc_info = exc_info
43        self.bugnumber = bugnumber
44
45class _UnexpectedSuccess(Exception):
46    """
47    The test was supposed to fail, but it didn't!
48    """
49
50    def __init__(self, exc_info, bugnumber=None):
51        # can't use super because Python 2.4 exceptions are old style
52        Exception.__init__(self)
53        self.exc_info = exc_info
54        self.bugnumber = bugnumber
55
56def _id(obj):
57    return obj
58
def skip(reason):
    """
    Decorator that unconditionally skips a test method or TestCase class.

    A TestCase subclass is only tagged with the skip attributes. Anything
    else (normally a test function) is replaced by a wrapper that raises
    SkipTest(reason) when called, and the wrapper is tagged instead.
    """
    def decorator(test_item):
        is_case_class = (isinstance(test_item, type)
                         and issubclass(test_item, TestCase))
        if is_case_class:
            target = test_item
        else:
            @wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            target = skip_wrapper

        # TestCase.run inspects these attributes to report the skip.
        target.__unittest_skip__ = True
        target.__unittest_skip_why__ = reason
        return target
    return decorator
74
def skipIf(condition, reason):
    """
    Skip the decorated test when ``condition`` is true; otherwise return the
    test unchanged.
    """
    return skip(reason) if condition else _id
82
def skipUnless(condition, reason):
    """
    Skip the decorated test unless ``condition`` is true; when it is true the
    test is returned unchanged.
    """
    return _id if condition else skip(reason)
90
def _wrapExpectedFailure(func, bugnumber):
    """Wrap test ``func`` so its outcome is reported as an expected failure.

    The wrapper raises _ExpectedFailure (carrying sys.exc_info() of the caught
    exception and ``bugnumber``) if ``func`` raises any Exception, and
    _UnexpectedSuccess if ``func`` returns normally.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info(), bugnumber)
        raise _UnexpectedSuccess(sys.exc_info(), bugnumber)
    return wrapper

def expectedFailure(bugnumber=None):
    """Mark a test as expected to fail.

    Supports three spellings::

        @expectedFailure            # bare; no bug number recorded
        @expectedFailure()          # same, called without an argument
        @expectedFailure(1234)      # records 1234 as the bug number

    TestCase.run reports a raised exception as an expected failure and a
    normal return as an unexpected success, passing the bug number along.
    """
    if callable(bugnumber):
        # Bare-decorator form: the single argument is the test function.
        return _wrapExpectedFailure(bugnumber, None)

    def expectedFailure_impl(func):
        return _wrapExpectedFailure(func, bugnumber)
    return expectedFailure_impl
112
113class _AssertRaisesContext(object):
114    """A context manager used to implement TestCase.assertRaises* methods."""
115
116    def __init__(self, expected, test_case, expected_regexp=None):
117        self.expected = expected
118        self.failureException = test_case.failureException
119        self.expected_regexp = expected_regexp
120
121    def __enter__(self):
122        return self
123
124    def __exit__(self, exc_type, exc_value, tb):
125        if exc_type is None:
126            try:
127                exc_name = self.expected.__name__
128            except AttributeError:
129                exc_name = str(self.expected)
130            raise self.failureException(
131                "%s not raised" % (exc_name,))
132        if not issubclass(exc_type, self.expected):
133            # let unexpected exceptions pass through
134            return False
135        self.exception = exc_value # store for later retrieval
136        if self.expected_regexp is None:
137            return True
138
139        expected_regexp = self.expected_regexp
140        if isinstance(expected_regexp, basestring):
141            expected_regexp = re.compile(expected_regexp)
142        if not expected_regexp.search(str(exc_value)):
143            raise self.failureException('"%s" does not match "%s"' %
144                     (expected_regexp.pattern, str(exc_value)))
145        return True
146
147
148class _TypeEqualityDict(object):
149
150    def __init__(self, testcase):
151        self.testcase = testcase
152        self._store = {}
153
154    def __setitem__(self, key, value):
155        self._store[key] = value
156
157    def __getitem__(self, key):
158        value = self._store[key]
159        if isinstance(value, basestring):
160            return getattr(self.testcase, value)
161        return value
162
163    def get(self, key, default=None):
164        if key in self._store:
165            return self[key]
166        return default
167
168
169class TestCase(unittest.TestCase):
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Construction
    and teardown of the test's environment ('fixture') can be
    implemented by overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. It is important that subclasses
    should not change the signature of their __init__ method, since instances
    of the classes are instantiated automatically by parts of the framework
    in order to be run.
    """

    # Exception class raised by the assertion methods on failure. When a test
    # method raises it, run() reports result.addFailure (a 'failure') instead
    # of result.addError (an 'error').

    failureException = AssertionError

    # Maximum length of a diff included in failure messages by the
    # difflib-based assert methods; longer diffs are replaced by the
    # DIFF_OMITTED note, and None disables truncation. Read through self, so
    # individual tests may override it.

    maxDiff = 80*8

    # When True, failure messages use the standard message (e.g. reprs of the
    # compared objects) with any explicit msg appended after ' : '. When
    # False, an explicit msg replaces the standard message entirely.

    longMessage = True

    # Attribute used by TestSuite for classSetUp (not read by the visible
    # methods of this class — presumably set when setUpClass fails; confirm).

    _classSetupFailed = False
212
213    def __init__(self, methodName='runTest'):
214        """Create an instance of the class that will use the named test
215           method when executed. Raises a ValueError if the instance does
216           not have a method with the specified name.
217        """
218        self._testMethodName = methodName
219        self._resultForDoCleanups = None
220        try:
221            testMethod = getattr(self, methodName)
222        except AttributeError:
223            raise ValueError("no such test method in %s: %s" % \
224                  (self.__class__, methodName))
225        self._testMethodDoc = testMethod.__doc__
226        self._cleanups = []
227
228        # Map types to custom assertEqual functions that will compare
229        # instances of said type in more detail to generate a more useful
230        # error message.
231        self._type_equality_funcs = _TypeEqualityDict(self)
232        self.addTypeEqualityFunc(dict, 'assertDictEqual')
233        self.addTypeEqualityFunc(list, 'assertListEqual')
234        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
235        self.addTypeEqualityFunc(set, 'assertSetEqual')
236        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
237        self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
238
239    def addTypeEqualityFunc(self, typeobj, function):
240        """Add a type specific assertEqual style function to compare a type.
241
242        This method is for use by TestCase subclasses that need to register
243        their own type equality functions to provide nicer error messages.
244
245        Args:
246            typeobj: The data type to call this function on when both values
247                    are of the same type in assertEqual().
248            function: The callable taking two arguments and an optional
249                    msg= argument that raises self.failureException with a
250                    useful error message when the two arguments are not equal.
251        """
252        self._type_equality_funcs[typeobj] = function
253
254    def addCleanup(self, function, *args, **kwargs):
255        """Add a function, with arguments, to be called when the test is
256        completed. Functions added are called on a LIFO basis and are
257        called after tearDown on test failure or success.
258
259        Cleanup items are called even if setUp fails (unlike tearDown)."""
260        self._cleanups.append((function, args, kwargs))
261
262    def setUp(self):
263        "Hook method for setting up the test fixture before exercising it."
264
265    @classmethod
266    def setUpClass(cls):
267        "Hook method for setting up class fixture before running tests in the class."
268
269    @classmethod
270    def tearDownClass(cls):
271        "Hook method for deconstructing the class fixture after running all tests in the class."
272
273    def tearDown(self):
274        "Hook method for deconstructing the test fixture after testing it."
275
276    def countTestCases(self):
277        return 1
278
279    def defaultTestResult(self):
280        return result.TestResult()
281
282    def shortDescription(self):
283        """Returns a one-line description of the test, or None if no
284        description has been provided.
285
286        The default implementation of this method returns the first line of
287        the specified test method's docstring.
288        """
289        doc = self._testMethodDoc
290        return doc and doc.split("\n")[0].strip() or None
291
292
293    def id(self):
294        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
295
296    def __eq__(self, other):
297        if type(self) is not type(other):
298            return NotImplemented
299
300        return self._testMethodName == other._testMethodName
301
302    def __ne__(self, other):
303        return not self == other
304
305    def __hash__(self):
306        return hash((type(self), self._testMethodName))
307
308    def __str__(self):
309        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
310
311    def __repr__(self):
312        return "<%s testMethod=%s>" % \
313               (strclass(self.__class__), self._testMethodName)
314
315    def _addSkip(self, result, reason):
316        addSkip = getattr(result, 'addSkip', None)
317        if addSkip is not None:
318            addSkip(self, reason)
319        else:
320            warnings.warn("Use of a TestResult without an addSkip method is deprecated",
321                          DeprecationWarning, 2)
322            result.addSuccess(self)
323
324    def run(self, result=None):
325        orig_result = result
326        if result is None:
327            result = self.defaultTestResult()
328            startTestRun = getattr(result, 'startTestRun', None)
329            if startTestRun is not None:
330                startTestRun()
331
332        self._resultForDoCleanups = result
333        result.startTest(self)
334
335        testMethod = getattr(self, self._testMethodName)
336
337        if (getattr(self.__class__, "__unittest_skip__", False) or
338            getattr(testMethod, "__unittest_skip__", False)):
339            # If the class or method was skipped.
340            try:
341                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
342                            or getattr(testMethod, '__unittest_skip_why__', ''))
343                self._addSkip(result, skip_why)
344            finally:
345                result.stopTest(self)
346            return
347        try:
348            success = False
349            try:
350                self.setUp()
351            except SkipTest, e:
352                self._addSkip(result, str(e))
353            except Exception:
354                result.addError(self, sys.exc_info())
355            else:
356                try:
357                    testMethod()
358                except self.failureException:
359                    result.addFailure(self, sys.exc_info())
360                except _ExpectedFailure, e:
361                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
362                    if addExpectedFailure is not None:
363                        addExpectedFailure(self, e.exc_info, e.bugnumber)
364                    else:
365                        warnings.warn("Use of a TestResult without an addExpectedFailure method is deprecated",
366                                      DeprecationWarning)
367                        result.addSuccess(self)
368                except _UnexpectedSuccess, x:
369                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
370                    if addUnexpectedSuccess is not None:
371                        addUnexpectedSuccess(self, x.bugnumber)
372                    else:
373                        warnings.warn("Use of a TestResult without an addUnexpectedSuccess method is deprecated",
374                                      DeprecationWarning)
375                        result.addFailure(self, sys.exc_info())
376                except SkipTest, e:
377                    self._addSkip(result, str(e))
378                except Exception:
379                    result.addError(self, sys.exc_info())
380                else:
381                    success = True
382
383                try:
384                    self.tearDown()
385                except Exception:
386                    result.addError(self, sys.exc_info())
387                    success = False
388
389            cleanUpSuccess = self.doCleanups()
390            success = success and cleanUpSuccess
391            if success:
392                result.addSuccess(self)
393        finally:
394            result.stopTest(self)
395            if orig_result is None:
396                stopTestRun = getattr(result, 'stopTestRun', None)
397                if stopTestRun is not None:
398                    stopTestRun()
399
400    def doCleanups(self):
401        """Execute all cleanup functions. Normally called for you after
402        tearDown."""
403        result = self._resultForDoCleanups
404        ok = True
405        while self._cleanups:
406            function, args, kwargs = self._cleanups.pop(-1)
407            try:
408                function(*args, **kwargs)
409            except Exception:
410                ok = False
411                result.addError(self, sys.exc_info())
412        return ok
413
414    def __call__(self, *args, **kwds):
415        return self.run(*args, **kwds)
416
417    def debug(self):
418        """Run the test without collecting errors in a TestResult"""
419        self.setUp()
420        getattr(self, self._testMethodName)()
421        self.tearDown()
422        while self._cleanups:
423            function, args, kwargs = self._cleanups.pop(-1)
424            function(*args, **kwargs)
425
426    def skipTest(self, reason):
427        """Skip this test."""
428        raise SkipTest(reason)
429
430    def fail(self, msg=None):
431        """Fail immediately, with the given message."""
432        raise self.failureException(msg)
433
434    def assertFalse(self, expr, msg=None):
435        "Fail the test if the expression is true."
436        if expr:
437            msg = self._formatMessage(msg, "%s is not False" % safe_repr(expr))
438            raise self.failureException(msg)
439
440    def assertTrue(self, expr, msg=None):
441        """Fail the test unless the expression is true."""
442        if not expr:
443            msg = self._formatMessage(msg, "%s is not True" % safe_repr(expr))
444            raise self.failureException(msg)
445
446    def _formatMessage(self, msg, standardMsg):
447        """Honour the longMessage attribute when generating failure messages.
448        If longMessage is False this means:
449        * Use only an explicit message if it is provided
450        * Otherwise use the standard message for the assert
451
452        If longMessage is True:
453        * Use the standard message
454        * If an explicit message is provided, plus ' : ' and the explicit message
455        """
456        if not self.longMessage:
457            return msg or standardMsg
458        if msg is None:
459            return standardMsg
460        try:
461            return '%s : %s' % (standardMsg, msg)
462        except UnicodeDecodeError:
463            return '%s : %s' % (safe_str(standardMsg), safe_str(msg))
464
465
466    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
467        """Fail unless an exception of class excClass is thrown
468           by callableObj when invoked with arguments args and keyword
469           arguments kwargs. If a different type of exception is
470           thrown, it will not be caught, and the test case will be
471           deemed to have suffered an error, exactly as for an
472           unexpected exception.
473
474           If called with callableObj omitted or None, will return a
475           context object used like this::
476
477                with self.assertRaises(SomeException):
478                    do_something()
479
480           The context manager keeps a reference to the exception as
481           the 'exception' attribute. This allows you to inspect the
482           exception after the assertion::
483
484               with self.assertRaises(SomeException) as cm:
485                   do_something()
486               the_exception = cm.exception
487               self.assertEqual(the_exception.error_code, 3)
488        """
489        if callableObj is None:
490            return _AssertRaisesContext(excClass, self)
491        try:
492            callableObj(*args, **kwargs)
493        except excClass:
494            return
495
496        if hasattr(excClass,'__name__'):
497            excName = excClass.__name__
498        else:
499            excName = str(excClass)
500        raise self.failureException, "%s not raised" % excName
501
502    def _getAssertEqualityFunc(self, first, second):
503        """Get a detailed comparison function for the types of the two args.
504
505        Returns: A callable accepting (first, second, msg=None) that will
506        raise a failure exception if first != second with a useful human
507        readable error message for those types.
508        """
509        #
510        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
511        # and vice versa.  I opted for the conservative approach in case
512        # subclasses are not intended to be compared in detail to their super
513        # class instances using a type equality func.  This means testing
514        # subtypes won't automagically use the detailed comparison.  Callers
515        # should use their type specific assertSpamEqual method to compare
516        # subclasses if the detailed comparison is desired and appropriate.
517        # See the discussion in http://bugs.python.org/issue2578.
518        #
519        if type(first) is type(second):
520            asserter = self._type_equality_funcs.get(type(first))
521            if asserter is not None:
522                return asserter
523
524        return self._baseAssertEqual
525
526    def _baseAssertEqual(self, first, second, msg=None):
527        """The default assertEqual implementation, not type specific."""
528        if not first == second:
529            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
530            msg = self._formatMessage(msg, standardMsg)
531            raise self.failureException(msg)
532
533    def assertEqual(self, first, second, msg=None):
534        """Fail if the two objects are unequal as determined by the '=='
535           operator.
536        """
537        assertion_func = self._getAssertEqualityFunc(first, second)
538        assertion_func(first, second, msg=msg)
539
540    def assertNotEqual(self, first, second, msg=None):
541        """Fail if the two objects are equal as determined by the '=='
542           operator.
543        """
544        if not first != second:
545            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
546                                                           safe_repr(second)))
547            raise self.failureException(msg)
548
549    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
550        """Fail if the two objects are unequal as determined by their
551           difference rounded to the given number of decimal places
552           (default 7) and comparing to zero, or by comparing that the
553           between the two objects is more than the given delta.
554
555           Note that decimal places (from zero) are usually not the same
556           as significant digits (measured from the most signficant digit).
557
558           If the two objects compare equal then they will automatically
559           compare almost equal.
560        """
561        if first == second:
562            # shortcut
563            return
564        if delta is not None and places is not None:
565            raise TypeError("specify delta or places not both")
566
567        if delta is not None:
568            if abs(first - second) <= delta:
569                return
570
571            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
572                                                        safe_repr(second),
573                                                        safe_repr(delta))
574        else:
575            if places is None:
576                places = 7
577
578            if round(abs(second-first), places) == 0:
579                return
580
581            standardMsg = '%s != %s within %r places' % (safe_repr(first),
582                                                          safe_repr(second),
583                                                          places)
584        msg = self._formatMessage(msg, standardMsg)
585        raise self.failureException(msg)
586
587    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
588        """Fail if the two objects are equal as determined by their
589           difference rounded to the given number of decimal places
590           (default 7) and comparing to zero, or by comparing that the
591           between the two objects is less than the given delta.
592
593           Note that decimal places (from zero) are usually not the same
594           as significant digits (measured from the most signficant digit).
595
596           Objects that are equal automatically fail.
597        """
598        if delta is not None and places is not None:
599            raise TypeError("specify delta or places not both")
600        if delta is not None:
601            if not (first == second) and abs(first - second) > delta:
602                return
603            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
604                                                        safe_repr(second),
605                                                        safe_repr(delta))
606        else:
607            if places is None:
608                places = 7
609            if not (first == second) and round(abs(second-first), places) != 0:
610                return
611            standardMsg = '%s == %s within %r places' % (safe_repr(first),
612                                                         safe_repr(second),
613                                                         places)
614
615        msg = self._formatMessage(msg, standardMsg)
616        raise self.failureException(msg)
617
618    # Synonyms for assertion methods
619
620    # The plurals are undocumented.  Keep them that way to discourage use.
621    # Do not add more.  Do not remove.
622    # Going through a deprecation cycle on these would annoy many people.
623    assertEquals = assertEqual
624    assertNotEquals = assertNotEqual
625    assertAlmostEquals = assertAlmostEqual
626    assertNotAlmostEquals = assertNotAlmostEqual
627    assert_ = assertTrue
628
629    # These fail* assertion method names are pending deprecation and will
630    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
631    def _deprecate(original_func):
632        def deprecated_func(*args, **kwargs):
633            warnings.warn(
634                ('Please use %s instead.' % original_func.__name__),
635                PendingDeprecationWarning, 2)
636            return original_func(*args, **kwargs)
637        return deprecated_func
638
639    failUnlessEqual = _deprecate(assertEqual)
640    failIfEqual = _deprecate(assertNotEqual)
641    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
642    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
643    failUnless = _deprecate(assertTrue)
644    failUnlessRaises = _deprecate(assertRaises)
645    failIf = _deprecate(assertFalse)
646
    def assertSequenceEqual(self, seq1, seq2,
                            msg=None, seq_type=None, max_diff=80*8):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        The failure message names the first differing index or the extra
        trailing elements, followed by an ndiff of the pretty-printed
        sequences. The diff is truncated according to self.maxDiff.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message combined with the generated description
                    via _formatMessage (see longMessage).
            max_diff: Accepted but not read by this method. Diff truncation
                    is controlled by self.maxDiff instead.

        Note: when seq_type is None, sequences of different types with equal
        lengths and pairwise-equal elements are treated as equal.
        """
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # `differing` accumulates the human-readable description; it stays
        # None while both arguments still look like sequences.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            # Header line with reprs shortened to 30 characters.
            seq1_repr = repr(seq1)
            seq2_repr = repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Report only the first point of difference in the common prefix.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 (i, item1, item2))
                    break
            else:
                # Loop completed: every element of the common prefix matched.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe any trailing elements present in only one sequence;
            # "First extra element" means the first of those extra elements.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, seq1[len2]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, seq2[len1]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        # Full line diff of the pretty-printed sequences; _truncateMessage
        # drops it in favour of a length note when it exceeds self.maxDiff.
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))

        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
752
753    def _truncateMessage(self, message, diff):
754        max_diff = self.maxDiff
755        if max_diff is None or len(diff) <= max_diff:
756            return message + diff
757        return message + (DIFF_OMITTED % len(diff))
758
759    def assertListEqual(self, list1, list2, msg=None):
760        """A list-specific equality assertion.
761
762        Args:
763            list1: The first list to compare.
764            list2: The second list to compare.
765            msg: Optional message to use on failure instead of a list of
766                    differences.
767
768        """
769        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
770
771    def assertTupleEqual(self, tuple1, tuple2, msg=None):
772        """A tuple-specific equality assertion.
773
774        Args:
775            tuple1: The first tuple to compare.
776            tuple2: The second tuple to compare.
777            msg: Optional message to use on failure instead of a list of
778                    differences.
779        """
780        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
781
782    def assertSetEqual(self, set1, set2, msg=None):
783        """A set-specific equality assertion.
784
785        Args:
786            set1: The first set to compare.
787            set2: The second set to compare.
788            msg: Optional message to use on failure instead of a list of
789                    differences.
790
791        assertSetEqual uses ducktyping to support
792        different types of sets, and is optimized for sets specifically
793        (parameters must support a difference method).
794        """
795        try:
796            difference1 = set1.difference(set2)
797        except TypeError, e:
798            self.fail('invalid type when attempting set difference: %s' % e)
799        except AttributeError, e:
800            self.fail('first argument does not support set difference: %s' % e)
801
802        try:
803            difference2 = set2.difference(set1)
804        except TypeError, e:
805            self.fail('invalid type when attempting set difference: %s' % e)
806        except AttributeError, e:
807            self.fail('second argument does not support set difference: %s' % e)
808
809        if not (difference1 or difference2):
810            return
811
812        lines = []
813        if difference1:
814            lines.append('Items in the first set but not the second:')
815            for item in difference1:
816                lines.append(repr(item))
817        if difference2:
818            lines.append('Items in the second set but not the first:')
819            for item in difference2:
820                lines.append(repr(item))
821
822        standardMsg = '\n'.join(lines)
823        self.fail(self._formatMessage(msg, standardMsg))
824
825    def assertIn(self, member, container, msg=None):
826        """Just like self.assertTrue(a in b), but with a nicer default message."""
827        if member not in container:
828            standardMsg = '%s not found in %s' % (safe_repr(member),
829                                                   safe_repr(container))
830            self.fail(self._formatMessage(msg, standardMsg))
831
832    def assertNotIn(self, member, container, msg=None):
833        """Just like self.assertTrue(a not in b), but with a nicer default message."""
834        if member in container:
835            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
836                                                            safe_repr(container))
837            self.fail(self._formatMessage(msg, standardMsg))
838
839    def assertIs(self, expr1, expr2, msg=None):
840        """Just like self.assertTrue(a is b), but with a nicer default message."""
841        if expr1 is not expr2:
842            standardMsg = '%s is not %s' % (safe_repr(expr1), safe_repr(expr2))
843            self.fail(self._formatMessage(msg, standardMsg))
844
845    def assertIsNot(self, expr1, expr2, msg=None):
846        """Just like self.assertTrue(a is not b), but with a nicer default message."""
847        if expr1 is expr2:
848            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
849            self.fail(self._formatMessage(msg, standardMsg))
850
851    def assertDictEqual(self, d1, d2, msg=None):
852        self.assert_(isinstance(d1, dict), 'First argument is not a dictionary')
853        self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary')
854
855        if d1 != d2:
856            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
857            diff = ('\n' + '\n'.join(difflib.ndiff(
858                           pprint.pformat(d1).splitlines(),
859                           pprint.pformat(d2).splitlines())))
860            standardMsg = self._truncateMessage(standardMsg, diff)
861            self.fail(self._formatMessage(msg, standardMsg))
862
863    def assertDictContainsSubset(self, expected, actual, msg=None):
864        """Checks whether actual is a superset of expected."""
865        missing = []
866        mismatched = []
867        for key, value in expected.iteritems():
868            if key not in actual:
869                missing.append(key)
870            elif value != actual[key]:
871                mismatched.append('%s, expected: %s, actual: %s' %
872                                  (safe_repr(key), safe_repr(value),
873                                   safe_repr(actual[key])))
874
875        if not (missing or mismatched):
876            return
877
878        standardMsg = ''
879        if missing:
880            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
881                                                    missing)
882        if mismatched:
883            if standardMsg:
884                standardMsg += '; '
885            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
886
887        self.fail(self._formatMessage(msg, standardMsg))
888
889    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
890        """An unordered sequence specific comparison. It asserts that
891        expected_seq and actual_seq contain the same elements. It is
892        the equivalent of::
893
894            self.assertEqual(sorted(expected_seq), sorted(actual_seq))
895
896        Raises with an error message listing which elements of expected_seq
897        are missing from actual_seq and vice versa if any.
898
899        Asserts that each element has the same count in both sequences.
900        Example:
901            - [0, 1, 1] and [1, 0, 1] compare equal.
902            - [0, 0, 1] and [0, 1] compare unequal.
903        """
904        try:
905            expected = sorted(expected_seq)
906            actual = sorted(actual_seq)
907        except TypeError:
908            # Unsortable items (example: set(), complex(), ...)
909            expected = list(expected_seq)
910            actual = list(actual_seq)
911            missing, unexpected = unorderable_list_difference(
912                expected, actual, ignore_duplicate=False
913            )
914        else:
915            return self.assertSequenceEqual(expected, actual, msg=msg)
916
917        errors = []
918        if missing:
919            errors.append('Expected, but missing:\n    %s' %
920                           safe_repr(missing))
921        if unexpected:
922            errors.append('Unexpected, but present:\n    %s' %
923                           safe_repr(unexpected))
924        if errors:
925            standardMsg = '\n'.join(errors)
926            self.fail(self._formatMessage(msg, standardMsg))
927
928    def assertMultiLineEqual(self, first, second, msg=None):
929        """Assert that two multi-line strings are equal."""
930        self.assert_(isinstance(first, basestring), (
931                'First argument is not a string'))
932        self.assert_(isinstance(second, basestring), (
933                'Second argument is not a string'))
934
935        if first != second:
936            standardMsg = '%s != %s' % (safe_repr(first, True), safe_repr(second, True))
937            diff = '\n' + ''.join(difflib.ndiff(first.splitlines(True),
938                                                       second.splitlines(True)))
939            standardMsg = self._truncateMessage(standardMsg, diff)
940            self.fail(self._formatMessage(msg, standardMsg))
941
942    def assertLess(self, a, b, msg=None):
943        """Just like self.assertTrue(a < b), but with a nicer default message."""
944        if not a < b:
945            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
946            self.fail(self._formatMessage(msg, standardMsg))
947
948    def assertLessEqual(self, a, b, msg=None):
949        """Just like self.assertTrue(a <= b), but with a nicer default message."""
950        if not a <= b:
951            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
952            self.fail(self._formatMessage(msg, standardMsg))
953
954    def assertGreater(self, a, b, msg=None):
955        """Just like self.assertTrue(a > b), but with a nicer default message."""
956        if not a > b:
957            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
958            self.fail(self._formatMessage(msg, standardMsg))
959
960    def assertGreaterEqual(self, a, b, msg=None):
961        """Just like self.assertTrue(a >= b), but with a nicer default message."""
962        if not a >= b:
963            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
964            self.fail(self._formatMessage(msg, standardMsg))
965
966    def assertIsNone(self, obj, msg=None):
967        """Same as self.assertTrue(obj is None), with a nicer default message."""
968        if obj is not None:
969            standardMsg = '%s is not None' % (safe_repr(obj),)
970            self.fail(self._formatMessage(msg, standardMsg))
971
972    def assertIsNotNone(self, obj, msg=None):
973        """Included for symmetry with assertIsNone."""
974        if obj is None:
975            standardMsg = 'unexpectedly None'
976            self.fail(self._formatMessage(msg, standardMsg))
977
978    def assertIsInstance(self, obj, cls, msg=None):
979        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
980        default message."""
981        if not isinstance(obj, cls):
982            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
983            self.fail(self._formatMessage(msg, standardMsg))
984
985    def assertNotIsInstance(self, obj, cls, msg=None):
986        """Included for symmetry with assertIsInstance."""
987        if isinstance(obj, cls):
988            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
989            self.fail(self._formatMessage(msg, standardMsg))
990
991    def assertRaisesRegexp(self, expected_exception, expected_regexp,
992                           callable_obj=None, *args, **kwargs):
993        """Asserts that the message in a raised exception matches a regexp.
994
995        Args:
996            expected_exception: Exception class expected to be raised.
997            expected_regexp: Regexp (re pattern object or string) expected
998                    to be found in error message.
999            callable_obj: Function to be called.
1000            args: Extra args.
1001            kwargs: Extra kwargs.
1002        """
1003        if callable_obj is None:
1004            return _AssertRaisesContext(expected_exception, self, expected_regexp)
1005        try:
1006            callable_obj(*args, **kwargs)
1007        except expected_exception, exc_value:
1008            if isinstance(expected_regexp, basestring):
1009                expected_regexp = re.compile(expected_regexp)
1010            if not expected_regexp.search(str(exc_value)):
1011                raise self.failureException('"%s" does not match "%s"' %
1012                         (expected_regexp.pattern, str(exc_value)))
1013        else:
1014            if hasattr(expected_exception, '__name__'):
1015                excName = expected_exception.__name__
1016            else:
1017                excName = str(expected_exception)
1018            raise self.failureException, "%s not raised" % excName
1019
1020
1021    def assertRegexpMatches(self, text, expected_regexp, msg=None):
1022        """Fail the test unless the text matches the regular expression."""
1023        if isinstance(expected_regexp, basestring):
1024            expected_regexp = re.compile(expected_regexp)
1025        if not expected_regexp.search(text):
1026            msg = msg or "Regexp didn't match"
1027            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1028            raise self.failureException(msg)
1029
1030    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1031        """Fail the test if the text matches the regular expression."""
1032        if isinstance(unexpected_regexp, basestring):
1033            unexpected_regexp = re.compile(unexpected_regexp)
1034        match = unexpected_regexp.search(text)
1035        if match:
1036            msg = msg or "Regexp matched"
1037            msg = '%s: %r matches %r in %r' % (msg,
1038                                               text[match.start():match.end()],
1039                                               unexpected_regexp.pattern,
1040                                               text)
1041            raise self.failureException(msg)
1042
class FunctionTestCase(TestCase):
    """Adapt a plain test function into a TestCase.

    Lets pre-existing test functions run under the unittest framework.
    Optional set-up and tidy-up callables may be supplied. As with TestCase,
    the tidy-up ('tearDown') callable is always called if the set-up
    ('setUp') callable ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # All four are stored as given; they drive equality, hashing and
        # the description below.
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up callable, if any."""
        hook = self._setUpFunc
        if hook is not None:
            hook()

    def tearDown(self):
        """Call the user-supplied tidy-up callable, if any."""
        hook = self._tearDownFunc
        if hook is not None:
            hook()

    def runTest(self):
        """Call the wrapped test function."""
        self._testFunc()

    def id(self):
        """Use the wrapped function's __name__ as the test id."""
        return self._testFunc.__name__

    def __eq__(self, other):
        # Defer to the other operand when it is not one of us.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash the same fields that __eq__ compares, plus the concrete type.
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        return "%s (%s)" % (strclass(self.__class__), self._testFunc.__name__)

    def __repr__(self):
        return "<%s testFunc=%s>" % (strclass(self.__class__), self._testFunc)

    def shortDescription(self):
        """Return the explicit description or the docstring's first line.

        Returns None when there is no description and the wrapped function's
        docstring is missing, empty, or has a blank first line.
        """
        description = self._description
        if description is not None:
            return description
        docstring = self._testFunc.__doc__
        if not docstring:
            return None
        first_line = docstring.split("\n")[0].strip()
        return first_line or None
1102