1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import collections
7import decimal
8import doctest
9import math
10import random
11import sys
12import unittest
13
14from decimal import Decimal
15from fractions import Fraction
16
17
18# Module to be tested.
19import statistics
20
21
22# === Helper functions and class ===
23
def sign(x):
    """Return the sign of x as a float of unit magnitude.

    Negative values, including -0.0, give -1.0; anything else gives +1.0.
    """
    unit = 1
    return math.copysign(unit, x)
27
def _nan_equal(a, b):
    """Report whether a and b are NANs of the same type and the same kind.

    Quiet and signalling Decimal NANs are different kinds:

    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
    True
    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
    True
    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
    False
    >>> _nan_equal(Decimal(42), Decimal('NAN'))
    False

    Floats only have the one kind:

    >>> _nan_equal(float('NAN'), float('NAN'))
    True
    >>> _nan_equal(float('NAN'), 0.5)
    False

    Arguments of different types never match:

    >>> _nan_equal(float('NAN'), Decimal('NAN'))
    False

    NAN payloads are not compared.
    """
    if type(a) is type(b):
        if isinstance(a, float):
            return math.isnan(a) and math.isnan(b)
        # Decimal marks a quiet NAN with exponent 'n' and a signalling NAN
        # with exponent 'N', so matching exponents means matching kinds.
        exponent_a = a.as_tuple().exponent
        exponent_b = b.as_tuple().exponent
        return exponent_a in ('n', 'N') and exponent_a == exponent_b
    return False
57
58
def _calc_errors(actual, expected):
    """Compute how far apart two numbers are, absolutely and relatively.

    >>> _calc_errors(100, 75)
    (25, 0.25)
    >>> _calc_errors(100, 100)
    (0, 0.0)

    The result is the tuple (absolute error, relative error). The relative
    error is measured against the larger magnitude of the two arguments,
    and is infinite when both arguments are zero.
    """
    scale = max(abs(actual), abs(expected))
    difference = abs(actual - expected)
    if scale:
        relative = difference/scale
    else:
        relative = float('inf')
    return difference, relative
73
74
def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Decide whether the numbers x and y are approximately equal, that is,
    equal to within some margin of error. Numbers which compare equal
    always compare approximately equal as well.

    The margin of error is the absolute error tol or the relative error
    rel (scaled by the larger magnitude of x and y), whichever is bigger.

    If given, both tol and rel must be finite, non-negative numbers. If not
    given, default values are tol=1e-12 and rel=1e-7.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    Absolute error is defined as abs(x-y); if that is less than or equal to
    tol, x and y are considered approximately equal.

    Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is
    smaller, provided x or y are not zero. If that figure is less than or
    equal to rel, x and y are considered approximately equal.

    Complex numbers are not directly supported. To compare complex
    numbers, extract their real and imaginary parts and compare them
    individually.

    NANs always compare unequal, even with themselves. Infinities compare
    approximately equal if they have the same sign (both positive or both
    negative). Infinities with different signs compare unequal; so do
    comparisons of infinities with finite numbers.

    Raises ValueError if tol or rel is negative.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    operands = (x, y)
    # A NAN is never equal to anything, approximately or otherwise.
    if any(map(math.isnan, operands)):
        return False
    # Exact equality implies approximate equality; two infinities of the
    # same sign are caught here too.
    if x == y:
        return True
    # Any infinity that gets this far is paired with a finite number or
    # with the infinity of opposite sign.
    if any(map(math.isinf, operands)):
        return False
    # Both values are finite from here on.
    difference = abs(x - y)
    largest = max(abs(x), abs(y))
    margin = max(tol, rel*largest)
    return difference <= margin
126
127
128# This class exists only as somewhere to stick a docstring containing
129# doctests. The following docstring and tests were originally in a separate
# module. Now that it has been merged in here, I need somewhere to hang the
131# docstring. Ultimately, this class will die, and the information below will
132# either become redundant, or be moved into more appropriate places.
class _DoNothing:
    """
    Exact equality is rarely the right test in numeric work, and least of
    all with floats: round-off error makes comparing floats with equality
    a bad idea more often than not. The usual remedy is to compare them
    with some (hopefully small!) allowance for error.

    The ``approx_equal`` function accepts an absolute error tolerance, a
    relative error, or both.

    An absolute error tolerance is the simpler of the two, but it only
    works if you know the magnitude of the quantities being compared:

    >>> approx_equal(12.345, 12.346, tol=1e-3)
    True
    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
    False

    A relative error is the better choice when the values being compared
    can vary in magnitude:

    >>> approx_equal(12.345, 12.346, rel=1e-4)
    True
    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
    True

    but a naive implementation of relative error testing can run into trouble
    around zero.

    When both an absolute tolerance and a relative error are supplied, the
    comparison succeeds as soon as either individual test succeeds:

    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
    True

    """
170
171
172
173# We prefer this for testing numeric values that may not be exactly equal,
174# and avoid using TestCase.assertAlmostEqual, because it sucks :-)
175
class NumericTestCase(unittest.TestCase):
    """Unit test class for numeric work.

    This subclasses TestCase. In addition to the standard method
    ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided.

    The class attributes ``tol`` and ``rel`` hold the default absolute and
    relative error tolerances used by ``assertApproxEqual``. Subclasses may
    override them.
    """
    # By default, we expect exact equality, unless overridden.
    tol = rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Test passes if ``first`` and ``second`` are approximately equal.

        This test passes if ``first`` and ``second`` are equal to
        within ``tol``, an absolute error, or ``rel``, a relative error.

        If either ``tol`` or ``rel`` are None or not given, they default to
        test attributes of the same name (by default, 0).

        The objects may be either numbers, or sequences of numbers. Sequences
        are tested element-by-element.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        # The Sequence ABC lives in collections.abc; the old alias
        # collections.Sequence was removed in Python 3.10. Import it
        # locally so this method does not depend on the alias.
        from collections.abc import Sequence
        if tol is None:
            tol = self.tol
        if rel is None:
            rel = self.rel
        if isinstance(first, Sequence) and isinstance(second, Sequence):
            check = self._check_approx_seq
        else:
            check = self._check_approx_num
        check(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        """Compare two sequences item by item, failing on the first mismatch.

        Sequences of different lengths fail immediately, before any of
        their items are compared.
        """
        if len(first) != len(second):
            standardMsg = (
                "sequences differ in length: %d items != %d items"
                % (len(first), len(second))
                )
            msg = self._formatMessage(msg, standardMsg)
            raise self.failureException(msg)
        for i, (a,e) in enumerate(zip(first, second)):
            self._check_approx_num(a, e, tol, rel, msg, i)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        """Compare two numbers, raising ``failureException`` on mismatch.

        ``idx`` is the position of the pair within an enclosing sequence,
        or None when two plain numbers are being compared. It is only used
        to build the failure message.
        """
        if approx_equal(first, second, tol, rel):
            # Test passes. Return early, we are done.
            return None
        # Otherwise we failed.
        standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
        msg = self._formatMessage(msg, standardMsg)
        raise self.failureException(msg)

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        """Return the standard failure message for an approx_equal mismatch.

        The message reports both values, the tolerances in force and the
        actual absolute and relative errors. When ``idx`` is not None a
        header naming the offending sequence index is prepended.
        """
        # Create the standard error message for approx_equal failures.
        assert first != second
        template = (
            '  %r != %r\n'
            '  values differ by more than tol=%r and rel=%r\n'
            '  -> absolute error = %r\n'
            '  -> relative error = %r'
            )
        if idx is not None:
            header = 'numeric sequences first differ at index %d.\n' % idx
            template = header + template
        # Calculate actual errors:
        abs_err, rel_err = _calc_errors(first, second)
        return template % (first, second, tol, rel, abs_err, rel_err)
265
266
267# ========================
268# === Test the helpers ===
269# ========================
270
class TestSign(unittest.TestCase):
    """Check that the sign() helper reports signs correctly."""
    def testZeroes(self):
        # A signed zero must report its own sign: +1 for 0.0, -1 for -0.0.
        for zero, expected in ((0.0, +1), (-0.0, -1)):
            self.assertEqual(sign(zero), expected)
277
278
279# --- Tests for approx_equal ---
280
class ApproxEqualSymmetryTest(unittest.TestCase):
    # Test symmetry of approx_equal: the result must not depend on the
    # order in which the two values are passed.

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        # Helper: compare a and b in both orders, using a relative error
        # that lies between the two possible one-sided relative errors.
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        # Helper: assert that swapping the operands does not change the
        # result of approx_equal.
        #
        # The template is filled in with str.format, so it needs a
        # replacement field; a printf-style "%r" would be left untouched
        # and the failure message would never show the arguments.
        template = "approx_equal comparisons don't match for {!r}"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        self.assertEqual(flag1, flag2, template.format((a, b, tol, rel)))
341
342
class ApproxEqualExactTest(unittest.TestCase):
    # Exercise approx_equal with pairs of identical values. Such pairs
    # must be reported as approximately equal whatever error tolerances
    # are supplied.

    def do_exactly_equal_test(self, x, tol, rel):
        # Compare x with itself, then -x with itself.
        for value in (x, -x):
            outcome = approx_equal(value, value, tol=tol, rel=rel)
            self.assertTrue(outcome, 'equality failure for x=%r' % value)

    def test_exactly_equal_ints(self):
        # Identical ints compare equal with zero tolerances.
        values = (42, 19740, 14974, 230, 1795, 700245, 36587)
        for value in values:
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_floats(self):
        # Identical floats compare equal with zero tolerances.
        values = (0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587)
        for value in values:
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_fractions(self):
        # Identical Fractions compare equal with zero tolerances.
        parts = ((1, 2), (0, 1), (5, 3), (9, 7), (35, 36), (3, 7))
        for numerator, denominator in parts:
            value = Fraction(numerator, denominator)
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_decimals(self):
        # Identical Decimals compare equal with zero tolerances.
        for text in ("8.2", "31.274", "912.04", "16.745", "1.2047"):
            self.do_exactly_equal_test(Decimal(text), 0, 0)

    def test_exactly_equal_absolute(self):
        # Identical values compare equal when only an absolute error is
        # allowed. Each n is tried as an int, a float and a Fraction.
        for n in (16, 1013, 1372, 1198, 971, 4):
            for value in (n, n/10, Fraction(n, 1234)):
                self.do_exactly_equal_test(value, 0.01, 0)

    def test_exactly_equal_absolute_decimals(self):
        # As above, but for Decimals with a Decimal absolute error.
        tolerance = Decimal("0.01")
        for value in (Decimal("3.571"), -Decimal("81.3971")):
            self.do_exactly_equal_test(value, tolerance, 0)

    def test_exactly_equal_relative(self):
        # Identical values compare equal when only a relative error is
        # allowed.
        values = (8347, 101.3, -7910.28, Fraction(5, 21))
        for value in values:
            self.do_exactly_equal_test(value, 0, 0.01)
        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))

    def test_exactly_equal_both(self):
        # Identical values compare equal when both tol and rel are given.
        values = (41017, 16.742, -813.02, Fraction(3, 8))
        for value in values:
            self.do_exactly_equal_test(value, 0.1, 0.01)
        self.do_exactly_equal_test(
            Decimal("7.2"), Decimal("0.1"), Decimal("0.01"))
405
406
class ApproxEqualUnequalTest(unittest.TestCase):
    # With both error tolerances set to zero, approx_equal is an exact
    # equality test, so values that differ must be reported as unequal.

    def do_exactly_unequal_test(self, x):
        # Compare x with x+1, then -x with -x+1.
        for value in (x, -x):
            outcome = approx_equal(value, value+1, tol=0, rel=0)
            self.assertFalse(outcome, 'inequality failure for x=%r' % value)

    def test_exactly_unequal_ints(self):
        # Differing ints are unequal with zero error tolerance.
        for value in (951, 572305, 478, 917, 17240):
            self.do_exactly_unequal_test(value)

    def test_exactly_unequal_floats(self):
        # Differing floats are unequal with zero error tolerance.
        for value in (9.51, 5723.05, 47.8, 9.17, 17.24):
            self.do_exactly_unequal_test(value)

    def test_exactly_unequal_fractions(self):
        # Differing Fractions are unequal with zero error tolerance.
        parts = ((1, 5), (7, 9), (12, 11), (101, 99023))
        for numerator, denominator in parts:
            self.do_exactly_unequal_test(Fraction(numerator, denominator))

    def test_exactly_unequal_decimals(self):
        # Differing Decimals are unequal with zero error tolerance.
        for text in ("3.1415", "298.12", "3.47", "18.996", "0.00245"):
            self.do_exactly_unequal_test(Decimal(text))
436
437
class ApproxEqualInexactTest(unittest.TestCase):
    # Inexact test cases for approx_equal: the two values being compared
    # are close to each other but not exactly equal.

    # === Absolute error tests ===

    def do_approx_equal_abs_test(self, x, delta):
        # Values delta either side of x must match x when the allowed
        # absolute error is 2*delta, and must not when it is delta/2.
        loose, tight = 2*delta, delta/2
        for y in (x + delta, x - delta):
            msg = "Test failure for x={!r}, y={!r}".format(x, y)
            self.assertTrue(approx_equal(x, y, tol=loose, rel=0), msg)
            self.assertFalse(approx_equal(x, y, tol=tight, rel=0), msg)

    def test_approx_equal_absolute_ints(self):
        # Approximate equality of ints with an absolute error.
        values = (-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110)
        for n in values:
            for delta in (10, 2):
                self.do_approx_equal_abs_test(n, delta)

    def test_approx_equal_absolute_floats(self):
        # Approximate equality of floats with an absolute error.
        values = (-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4)
        for x in values:
            for delta in (1.5, 0.01, 0.0001):
                self.do_approx_equal_abs_test(x, delta)

    def test_approx_equal_absolute_fractions(self):
        # Approximate equality of Fractions with an absolute error, given
        # first as a Fraction and then as a float.
        delta = Fraction(1, 29)
        for n in (-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71):
            f = Fraction(n, 29)
            self.do_approx_equal_abs_test(f, delta)
            self.do_approx_equal_abs_test(f, float(delta))

    def test_approx_equal_absolute_decimals(self):
        # Approximate equality of Decimals with an absolute error.
        delta = Decimal("0.01")
        for text in ("1.0", "3.5", "36.08", "61.79", "7912.3648"):
            d = Decimal(text)
            for value in (d, -d):
                self.do_approx_equal_abs_test(value, delta)

    def test_cross_zero(self):
        # The two values lie on opposite sides of zero.
        outcome = approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)
        self.assertTrue(outcome)

    # === Relative error tests ===

    def do_approx_equal_rel_test(self, x, delta):
        # Values a fraction delta either side of x must match x when the
        # allowed relative error is 2*delta, and must not when it is
        # delta/2.
        loose, tight = 2*delta, delta/2
        for y in (x*(1+delta), x*(1-delta)):
            msg = "Test failure for x={!r}, y={!r}".format(x, y)
            self.assertTrue(approx_equal(x, y, tol=0, rel=loose), msg)
            self.assertFalse(approx_equal(x, y, tol=0, rel=tight), msg)

    def test_approx_equal_relative_ints(self):
        # Approximate equality of ints with a relative error.
        # Each row is (first, second, rel, expected result).
        cases = (
            (64, 47, 0.36, True),
            (64, 47, 0.37, True),
            # ---
            (449, 512, 0.125, True),
            (448, 512, 0.125, True),
            (447, 512, 0.125, False),
            )
        for a, b, rel, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(approx_equal(a, b, tol=0, rel=rel))

    def test_approx_equal_relative_floats(self):
        # Approximate equality of floats with a relative error.
        for x in (-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074):
            for delta in (0.02, 0.0001):
                self.do_approx_equal_rel_test(x, delta)

    def test_approx_equal_relative_fractions(self):
        # Approximate equality of Fractions with a relative error, given
        # first as a Fraction and then as a float.
        delta = Fraction(3, 8)
        parts = ((3, 84), (17, 30), (49, 50), (92, 85))
        for numerator, denominator in parts:
            f = Fraction(numerator, denominator)
            for d in (delta, float(delta)):
                self.do_approx_equal_rel_test(f, d)
                self.do_approx_equal_rel_test(-f, d)

    def test_approx_equal_relative_decimals(self):
        # Approximate equality of Decimals with a relative error.
        texts = ("0.02", "1.0", "5.7", "13.67", "94.138", "91027.9321")
        for text in texts:
            d = Decimal(text)
            self.do_approx_equal_rel_test(d, Decimal("0.001"))
            self.do_approx_equal_rel_test(-d, Decimal("0.05"))

    # === Both absolute and relative error tests ===

    # There are four cases to consider:
    #   1) actual error <= both absolute and relative error
    #   2) actual error <= absolute error but > relative error
    #   3) actual error <= relative error but > absolute error
    #   4) actual error > both absolute and relative error

    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
        # Try tol alone, rel alone, then both together. The combined
        # comparison is expected to pass if either individual one does.
        cases = (
            (tol, 0, tol_flag),
            (0, rel, rel_flag),
            (tol, rel, tol_flag or rel_flag),
            )
        for tol_arg, rel_arg, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(approx_equal(a, b, tol=tol_arg, rel=rel_arg))

    def test_approx_equal_both1(self):
        # Actual error <= both absolute and relative error.
        for a, b, tol, rel in (
                (7.955, 7.952, 0.004, 3.8e-4),
                (-7.387, -7.386, 0.002, 0.0002),
                ):
            self.do_check_both(a, b, tol, rel, True, True)

    def test_approx_equal_both2(self):
        # Actual error <= absolute error but > relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)

    def test_approx_equal_both3(self):
        # Actual error <= relative error but > absolute error.
        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)

    def test_approx_equal_both4(self):
        # Actual error > both absolute and relative error.
        for a, b, tol, rel in (
                (2.78, 2.75, 0.01, 0.001),
                (971.44, 971.47, 0.02, 3e-5),
                ):
            self.do_check_both(a, b, tol, rel, False, False)
555
556
class ApproxEqualSpecialsTest(unittest.TestCase):
    # Exercise approx_equal with the special values: INFs, NANs and
    # signed zeroes.

    def test_inf(self):
        # An infinity matches only an infinity of the same sign.
        for make in (float, Decimal):
            positive = make('inf')
            negative = -positive
            # Default tolerances, zero tolerances, explicit tolerances.
            for tolerances in ((), (0, 0), (1, 0.01)):
                self.assertTrue(approx_equal(positive, positive, *tolerances))
            self.assertTrue(approx_equal(negative, negative))
            for other in (negative, 1000):
                self.assertFalse(approx_equal(positive, other))

    def test_nan(self):
        # A NAN matches nothing, not even itself.
        for make in (float, Decimal):
            nan = make('nan')
            others = (nan, make('inf'), 1000)
            for other in others:
                self.assertFalse(approx_equal(nan, other))

    def test_float_zeroes(self):
        # Negative and positive float zero are approximately equal.
        negative_zero = math.copysign(0.0, -1)
        outcome = approx_equal(negative_zero, 0.0, tol=0.1, rel=0.1)
        self.assertTrue(outcome)

    def test_decimal_zeroes(self):
        # Negative and positive Decimal zero are approximately equal.
        negative_zero = Decimal("-0.0")
        outcome = approx_equal(negative_zero, Decimal(0), tol=0.1, rel=0.1)
        self.assertTrue(outcome)
583
584
class TestApproxEqualErrors(unittest.TestCase):
    # Check that approx_equal rejects invalid error tolerances.

    def test_bad_tol(self):
        # A negative tol raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, -1, 0.1)

    def test_bad_rel(self):
        # A negative rel raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, 1, -0.1)
595
596
597# --- Tests for NumericTestCase ---
598
599# The formatting routine that generates the error messages is complex enough
600# that it too needs testing.
601
class TestNumericTestCase(unittest.TestCase):
    # NumericTestCase makes no promise about the exact wording of its
    # error messages, yet the messages still need some test to show they
    # are built correctly. The compromise used here is to search for
    # specific substrings that should survive any rewording of the
    # overall message.

    def do_test(self, args):
        # Build the message for args and check each expected substring.
        message = NumericTestCase._make_std_err_msg(*args)
        for fragment in self.generate_substrings(*args):
            self.assertIn(fragment, message)

    def test_numerictestcase_is_testcase(self):
        # NumericTestCase must really be a TestCase.
        is_testcase = issubclass(NumericTestCase, unittest.TestCase)
        self.assertTrue(is_testcase)

    def test_error_msg_numeric(self):
        # Error message for a comparison of two plain numbers (no index).
        self.do_test((2.5, 4.0, 0.5, 0.25, None))

    def test_error_msg_sequence(self):
        # Error message for a comparison inside a sequence, at index 7.
        self.do_test((3.75, 8.25, 1.25, 0.5, 7))

    def generate_substrings(self, first, second, tol, rel, idx):
        """Build the list of substrings an error message must contain."""
        abs_err, rel_err = _calc_errors(first, second)
        labelled = (
            ('tol=%r', tol),
            ('rel=%r', rel),
            ('absolute error = %r', abs_err),
            ('relative error = %r', rel_err),
            )
        fragments = [template % value for template, value in labelled]
        if idx is not None:
            # Sequence comparisons also name the failing index.
            fragments.append('differ at index %d' % idx)
        return fragments
640
641
642# =======================================
643# === Tests for the statistics module ===
644# =======================================
645
646
class GlobalsTest(unittest.TestCase):
    # Sanity checks on the module-level names of the module under test.
    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # Every expected metadata attribute exists on the module.
        target = self.module
        for attribute in self.expected_metadata:
            present = hasattr(target, attribute)
            self.assertTrue(present, "%s not present" % attribute)

    def test_check_all(self):
        # Every name listed in __all__ is public and actually exists.
        target = self.module
        for exported in target.__all__:
            # No private names in __all__:
            is_private = exported.startswith("_")
            self.assertFalse(
                is_private, 'private name "%s" in __all__' % exported)
            # And anything in __all__ must exist:
            exists = hasattr(target, exported)
            self.assertTrue(
                exists, 'missing name "%s" in __all__' % exported)
667
668
class DocTests(unittest.TestCase):
    # Run the doctests embedded in the statistics module itself.

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        outcome = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        # At least one example must have run, and none may have failed.
        self.assertGreater(outcome.attempted, 0)
        self.assertEqual(outcome.failed, 0)
676
class StatisticsErrorTest(unittest.TestCase):
    # Check that the module's own exception exists and is a ValueError.

    def test_has_exception(self):
        errmsg = (
                "Expected StatisticsError to be a ValueError, but got a"
                " subclass of %r instead."
                )
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exception_class = statistics.StatisticsError
        is_value_error = issubclass(exception_class, ValueError)
        # On failure, the message names the actual base class.
        self.assertTrue(is_value_error, errmsg % exception_class.__base__)
688
689
690# === Tests for private utility functions ===
691
class ExactRatioTest(unittest.TestCase):
    # Exercise the private _exact_ratio utility.

    def test_int(self):
        # An int i is expected back as the pair (i, 1).
        exact_ratio = statistics._exact_ratio
        for i in (-20, -3, 0, 5, 99, 10**20):
            self.assertEqual(exact_ratio(i), (i, 1))

    def test_fraction(self):
        # Fractions in lowest terms come back as (numerator, denominator).
        exact_ratio = statistics._exact_ratio
        for n in (-5, 1, 12, 38):
            self.assertEqual(exact_ratio(Fraction(n, 37)), (n, 37))

    def test_float(self):
        exact_ratio = statistics._exact_ratio
        for x, expected in ((0.125, (1, 8)), (1.125, (9, 8))):
            self.assertEqual(exact_ratio(x), expected)
        # For random floats, dividing the ratio must reproduce the float.
        data = [random.uniform(-100, 100) for _ in range(100)]
        for x in data:
            numerator, denominator = exact_ratio(x)
            self.assertEqual(x, numerator/denominator)

    def test_decimal(self):
        exact_ratio = statistics._exact_ratio
        cases = (
            ("0.125", (1, 8)),
            ("12.345", (2469, 200)),
            ("-1.98", (-99, 50)),
            )
        for text, expected in cases:
            self.assertEqual(exact_ratio(Decimal(text)), expected)

    def test_inf(self):
        # An infinity comes back as (x, None), keeping its type, for
        # float, Decimal and subclasses of each.
        class MyFloat(float):
            pass
        class MyDecimal(Decimal):
            pass
        positive = float("INF")
        for inf in (positive, -positive):
            for type_ in (float, MyFloat, Decimal, MyDecimal):
                x = type_(inf)
                ratio = statistics._exact_ratio(x)
                self.assertEqual(ratio, (x, None))
                self.assertEqual(type(ratio[0]), type_)
                self.assertTrue(math.isinf(ratio[0]))

    def test_float_nan(self):
        # A float NAN comes back as (nan, None), keeping its type.
        class MyFloat(float):
            pass
        plain = float("NAN")
        for nan in (plain, MyFloat(plain)):
            value, denominator = statistics._exact_ratio(nan)
            self.assertTrue(math.isnan(value))
            self.assertIs(denominator, None)
            self.assertEqual(type(value), type(nan))

    def test_decimal_nan(self):
        # A Decimal NAN, quiet or signalling, comes back as (nan, None),
        # keeping both its kind and its type.
        class MyDecimal(Decimal):
            pass
        quiet = Decimal("NAN")
        signalling = Decimal("sNAN")
        candidates = (
            quiet, MyDecimal(quiet), signalling, MyDecimal(signalling))
        for nan in candidates:
            value, denominator = statistics._exact_ratio(nan)
            self.assertTrue(_nan_equal(value, nan))
            self.assertIs(denominator, None)
            self.assertEqual(type(value), type(nan))
754
755
class DecimalToRatioTest(unittest.TestCase):
    # Exercise the private _exact_ratio function with Decimal arguments.

    def test_infinity(self):
        # INFs come back unchanged, paired with None.
        exact_ratio = statistics._exact_ratio
        positive = Decimal('INF')
        for inf in (positive, -positive):
            self.assertEqual(exact_ratio(inf), (inf, None))

    def test_nan(self):
        # NANs come back as a NAN of the same kind, paired with None.
        for text in ('NAN', 'sNAN'):
            nan = Decimal(text)
            numerator, denominator = statistics._exact_ratio(nan)
            # NANs always compare non-equal, which rules out assertEqual.
            # An identity test is out as well, because nothing is
            # guaranteed about the identity of the returned object.
            self.assertTrue(_nan_equal(numerator, nan))
            self.assertIs(denominator, None)

    def test_sign(self):
        # The sign belongs to the numerator; the denominator is positive.
        exact_ratio = statistics._exact_ratio
        for text in ("9.8765e12", "9.8765e-12"):
            positive = Decimal(text)
            # Positive decimals first.
            assert positive > 0
            numerator, denominator = exact_ratio(positive)
            self.assertGreaterEqual(numerator, 0)
            self.assertGreater(denominator, 0)
            # Then the negated value.
            numerator, denominator = exact_ratio(-positive)
            self.assertLessEqual(numerator, 0)
            self.assertGreater(denominator, 0)

    def test_negative_exponent(self):
        # Result when the exponent is negative.
        ratio = statistics._exact_ratio(Decimal("0.1234"))
        self.assertEqual(ratio, (617, 5000))

    def test_positive_exponent(self):
        # Result when the exponent is positive.
        ratio = statistics._exact_ratio(Decimal("1.234e7"))
        self.assertEqual(ratio, (12340000, 1))

    def test_regression_20536(self):
        # Regression test for issue 20536.
        # See http://bugs.python.org/issue20536
        cases = (("1e2", (100, 1)), ("1.47e5", (147000, 1)))
        for text, expected in cases:
            ratio = statistics._exact_ratio(Decimal(text))
            self.assertEqual(ratio, expected)
806
807
class IsFiniteTest(unittest.TestCase):
    # Exercise the private _isfinite function.

    def test_finite(self):
        # Ordinary int, Fraction, float and Decimal values are finite.
        finite_values = (5, Fraction(1, 3), 2.5, Decimal("5.5"))
        for value in finite_values:
            self.assertTrue(statistics._isfinite(value))

    def test_infinity(self):
        # Float and Decimal infinities are not finite.
        infinities = (float("inf"), Decimal("inf"))
        for value in infinities:
            self.assertFalse(statistics._isfinite(value))

    def test_nan(self):
        # Float NAN and both kinds of Decimal NAN are not finite.
        nans = (float("nan"), Decimal("NAN"), Decimal("sNAN"))
        for value in nans:
            self.assertFalse(statistics._isfinite(value))
825
826
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - (A further note beginning "The first time _sum calls _coerce, the"
    #     was left unfinished here.)
    #
    # Rules exercised by the tests below:
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.

    def test_bool(self):
        # bool is somewhat special, due to the pre-condition that it is
        # never given as the first argument to _coerce, and that it cannot
        # be subclassed. So it is only ever passed second here.
        for base in (int, float, Fraction, Decimal):
            self.assertIs(statistics._coerce(base, bool), base)
            class Derived(base): pass
            self.assertIs(statistics._coerce(Derived, bool), Derived)

    def assertCoerceTo(self, A, B):
        """Assert that A and B coerce to B, whichever is given first."""
        for first, second in ((A, B), (B, A)):
            self.assertIs(statistics._coerce(first, second), B)

    def check_coerce_to(self, A, B):
        """Check that A coerces to B, for subclasses of either as well."""
        class ChildOfA(A): pass
        class ChildOfB(B): pass
        # Plain A, then a subclass of A, against plain B; then the same two
        # against a subclass of B.
        pairs = (
            (A, B),
            (ChildOfA, B),
            (A, ChildOfB),
            (ChildOfA, ChildOfB),
        )
        for source, target in pairs:
            self.assertCoerceTo(source, target)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # NOTE(review): each pair reaches _coerce as ONE tuple argument, not
        # as two types, so the TypeError presumably comes from the missing
        # second argument rather than from the coercion rules. Confirm
        # before unpacking the pair: doing so changes what the callers of
        # this helper actually verify.
        for pair in ((A, B), (B, A)):
            self.assertRaises(TypeError, statistics._coerce, pair)

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        # A type coerced with itself is returned unchanged.
        self.assertIs(statistics._coerce(T, T), T)
        # A type coerced with one of its subclasses gives the subclass.
        class Child(T): pass
        class Sibling(T): pass
        class Grandchild(Child): pass
        for subclass in (Child, Sibling, Grandchild):
            self.assertCoerceTo(T, subclass)
        self.assertCoerceTo(Child, Grandchild)
        # Two subclasses that aren't parent/child are expected to be an error.
        self.assertCoerceRaises(Child, Sibling)
        self.assertCoerceRaises(Sibling, Grandchild)

    def test_int(self):
        # int coerces with itself and gives way to every other numeric type.
        self.check_type_coercions(int)
        for other in (float, Fraction, Decimal):
            self.check_coerce_to(int, other)

    def test_fraction(self):
        # Fraction coerces with itself and gives way to float.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Decimal coerces with itself and its subclasses.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # float coerces with itself and its subclasses.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        # Pair every numeric type with a selection of non-numeric ones.
        non_numeric = (str, list, type(None), tuple, dict)
        numeric = (int, float, Fraction, Decimal)
        for bad_type in non_numeric:
            for good_type in numeric:
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # float and Fraction, and subclasses of them, paired with Decimal.
        for base in (float, Fraction):
            class Derived(base): pass
            for candidate in (base, Derived):
                self.assertCoerceRaises(candidate, Decimal)
925
926
class ConvertTest(unittest.TestCase):
    # Exercise the private _convert function.

    def check_exact_equal(self, x, y):
        """Check that x and y are equal and of exactly the same type."""
        self.assertEqual(x, y)
        self.assertIs(type(x), type(y))

    def test_int(self):
        # Whole-valued Fractions convert to int and to an int subclass.
        converted = statistics._convert(Fraction(71), int)
        self.check_exact_equal(converted, 71)
        class IntSubclass(int): pass
        converted = statistics._convert(Fraction(17), IntSubclass)
        self.check_exact_equal(converted, IntSubclass(17))

    def test_fraction(self):
        # Fractions convert to Fraction and to a Fraction subclass.
        converted = statistics._convert(Fraction(95, 99), Fraction)
        self.check_exact_equal(converted, Fraction(95, 99))
        class FractionSubclass(Fraction):
            # Keep the result of a division in the subclass.
            def __truediv__(self, other):
                quotient = super().__truediv__(other)
                return self.__class__(quotient)
        converted = statistics._convert(Fraction(71, 13), FractionSubclass)
        self.check_exact_equal(converted, FractionSubclass(71, 13))

    def test_float(self):
        # Fractions convert to float and to a float subclass.
        converted = statistics._convert(Fraction(-1, 2), float)
        self.check_exact_equal(converted, -0.5)
        class FloatSubclass(float):
            # Keep the result of a division in the subclass.
            def __truediv__(self, other):
                quotient = super().__truediv__(other)
                return self.__class__(quotient)
        converted = statistics._convert(Fraction(9, 8), FloatSubclass)
        self.check_exact_equal(converted, FloatSubclass(1.125))

    def test_decimal(self):
        # Fractions convert to Decimal and to a Decimal subclass.
        converted = statistics._convert(Fraction(1, 40), Decimal)
        self.check_exact_equal(converted, Decimal("0.025"))
        class DecimalSubclass(Decimal):
            # Keep the result of a division in the subclass.
            def __truediv__(self, other):
                quotient = super().__truediv__(other)
                return self.__class__(quotient)
        converted = statistics._convert(Fraction(-15, 16), DecimalSubclass)
        self.check_exact_equal(converted, DecimalSubclass("-0.9375"))

    def test_inf(self):
        # Float and Decimal infinities of either sign convert to themselves.
        for positive in (float('inf'), Decimal('inf')):
            for value in (positive, -positive):
                converted = statistics._convert(value, type(value))
                self.check_exact_equal(converted, value)

    def test_nan(self):
        # Float NAN and both kinds of Decimal NAN convert to the same kind
        # of NAN; _nan_equal is used because NANs never compare equal.
        for value in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
            converted = statistics._convert(value, type(value))
            self.assertTrue(_nan_equal(converted, value))
983
984
class FailNegTest(unittest.TestCase):
    """Test _fail_neg private function."""

    def test_pass_through(self):
        # Non-negative values of every numeric type come out unchanged.
        original = [1, 2.0, Fraction(3), Decimal(4)]
        passed = list(statistics._fail_neg(original))
        self.assertEqual(original, passed)

    def test_negatives_raise(self):
        # A negative value of any numeric type raises StatisticsError as
        # soon as the returned iterator is advanced.
        for positive in [1, 2.0, Fraction(3), Decimal(4)]:
            stream = statistics._fail_neg([-positive])
            with self.assertRaises(statistics.StatisticsError):
                next(stream)

    def test_error_msg(self):
        # A caller-supplied message ends up as the exception's first
        # argument. The number is random so a fixed default cannot match.
        expected = "badness #%d" % random.randint(10000, 99999)
        try:
            next(statistics._fail_neg([-1], expected))
        except statistics.StatisticsError as err:
            reported = err.args[0]
        else:
            self.fail("expected exception, but it didn't happen")
        self.assertEqual(reported, expected)
1011
1012
1013# === Tests for public functions ===
1014
class UnivariateCommonMixin:
    # Common tests for most univariate functions that take a data argument.
    # The class using this mixin is expected to set self.func in setUp.

    def test_no_args(self):
        # Calling the function with no arguments at all is a TypeError.
        with self.assertRaises(TypeError):
            self.func()

    def test_empty_data(self):
        # Fail when the data argument (first argument) is empty.
        for nothing in ([], (), iter([])):
            with self.assertRaises(statistics.StatisticsError):
                self.func(nothing)

    def prepare_data(self):
        """Return int data for various tests."""
        values = list(range(10))
        in_order = sorted(values)
        # Keep shuffling until the values are definitely out of order.
        while values == in_order:
            random.shuffle(values)
        return values

    def test_no_inplace_modifications(self):
        # Test that the function does not modify its input data.
        data = self.prepare_data()
        assert len(data) != 1  # Necessary to avoid infinite loop.
        assert data != sorted(data)
        snapshot = list(data)
        assert data is not snapshot
        self.func(data)
        self.assertListEqual(data, snapshot, "data has been modified")

    def test_order_doesnt_matter(self):
        # Test that the order of data points doesn't change the result.

        # CAUTION: due to floating point rounding errors, the result actually
        # may depend on the order. Consider this test representing an ideal.
        # To avoid this test failing, only test with exact values such as ints
        # or Fractions.
        data = [1, 2, 3, 3, 3, 4, 5, 6]*100
        before = self.func(data)
        random.shuffle(data)
        after = self.func(data)
        self.assertEqual(before, after)

    def test_type_of_data_collection(self):
        # Test that the type of iterable data doesn't affect the result.
        class ListSubclass(list):
            pass
        class TupleSubclass(tuple):
            pass
        def as_generator(items):
            return (item for item in items)
        data = self.prepare_data()
        expected = self.func(data)
        containers = (
            list, tuple, iter, ListSubclass, TupleSubclass, as_generator,
        )
        for make in containers:
            self.assertEqual(self.func(make(data)), expected)

    def test_range_data(self):
        # Test that functions work with range objects.
        span = range(20, 50, 3)
        expected = self.func(list(span))
        self.assertEqual(self.func(span), expected)

    def test_bad_arg_types(self):
        # Test that function raises when given data of the wrong type.

        # Don't roll the following into a loop like this:
        #   for bad in list_of_bad:
        #       self.check_for_type_error(bad)
        #
        # Since assertRaises doesn't show the arguments that caused the test
        # failure, it is very difficult to debug these test failures when the
        # following are in a loop.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        # Helper: calling the function with these arguments must raise
        # TypeError.
        with self.assertRaises(TypeError):
            self.func(*args)

    def test_type_of_data_element(self):
        # Check the type of data elements doesn't affect the numeric result.
        # This is a weaker test than UnivariateTypeMixin.test_types_conserved,
        # because it checks the numeric result by equality, but not by type.
        class FloatSubclass(float):
            # Keep division and addition results in the subclass.
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__

        raw = self.prepare_data()
        expected = self.func(raw)
        # Results are converted to the type of the int-data result before
        # comparing, so only the value is checked.
        result_type = type(expected)
        for kind in (float, FloatSubclass, Decimal, Fraction):
            converted = [kind(x) for x in raw]
            result = result_type(self.func(converted))
            self.assertEqual(result, expected)
1112
1113
class UnivariateTypeMixin:
    """Mixin class for type-conserving functions.

    Holds the test(s) for functions whose result has the same type as the
    individual data points; for example, the mean of a list of Fractions
    should itself be a Fraction.

    Only tests that rely on the function returning the same type as its
    input data belong here; other type-related tests can live elsewhere.
    """
    def prepare_types_for_conservation_test(self):
        """Return the types which are expected to be conserved."""
        class MyFloat(float):
            # Each operator below re-wraps the plain float result so that
            # arithmetic on the subclass stays in the subclass.
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        return (float, Decimal, Fraction, MyFloat)

    def test_types_conserved(self):
        # Test that functions keeps the same type as their data points.
        # (Excludes mixed data types.) This only tests the type of the return
        # result, not the value.
        raw = self.prepare_data()
        for kind in self.prepare_types_for_conservation_test():
            converted = [kind(x) for x in raw]
            self.assertIs(type(self.func(converted)), kind)
1151
1152
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.

    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.

    # NOTE(review): neither base listed here derives from unittest.TestCase,
    # so this class is presumably never collected by the test loader.
    # Confirm whether that is intentional before adding a TestCase base:
    # some of the inherited tests (e.g. the empty-data one) expect behaviour
    # of the public functions rather than of _sum.
    def setUp(self):
        def simplified_sum(*args):
            # _sum returns a (type, total, count) triple. Convert the total
            # to that type so the mixin tests see one plain number.
            T, value, n = statistics._sum(*args)
            # _convert turns a value into the given type; _coerce, used here
            # previously, combines two *types* and is not meant to be handed
            # a numeric total.
            return statistics._convert(value, T)
        self.func = simplified_sum
1163
1164
class TestSum(NumericTestCase):
    # Test cases for statistics._sum() function.

    # These tests look at the entire three value tuple returned by _sum:
    # the data type, the total, and the number of data points.

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Empty data is accepted by _sum: the total is the start value (zero
        # when omitted), the type is that of the start value, and the count
        # is zero.
        for empty in ([], (), iter([])):
            self.assertEqual(self.func(empty), (int, Fraction(0), 0))
            self.assertEqual(self.func(empty, 23), (int, Fraction(23), 0))
            self.assertEqual(self.func(empty, 2.3), (float, Fraction(2.3), 0))

    def test_ints(self):
        # Sum ints, without and with a start value.
        outcome = self.func([1, 5, 3, -4, -8, 20, 42, 1])
        self.assertEqual(outcome, (int, Fraction(60), 8))
        outcome = self.func([4, 2, 3, -8, 7], 1000)
        self.assertEqual(outcome, (int, Fraction(1008), 5))

    def test_floats(self):
        # Sum exactly-representable floats, without and with a start value.
        outcome = self.func([0.25]*20)
        self.assertEqual(outcome, (float, Fraction(5.0), 20))
        outcome = self.func([0.125, 0.25, 0.5, 0.75], 1.5)
        self.assertEqual(outcome, (float, Fraction(3.125), 4))

    def test_fractions(self):
        # Sum Fractions.
        outcome = self.func([Fraction(1, 1000)]*500)
        self.assertEqual(outcome, (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        # Sum Decimals.
        texts = ("0.001", "5.246", "1.702", "-0.025",
                 "3.974", "2.328", "4.617", "2.843")
        data = [Decimal(text) for text in texts]
        outcome = self.func(data)
        self.assertEqual(outcome, (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # Compare with the math.fsum function.
        # Ideally we ought to get the exact same result, but sometimes
        # we differ by a very slight amount :-(
        data = [random.uniform(-100, 1000) for _ in range(1000)]
        total = float(self.func(data)[1])
        self.assertApproxEqual(total, math.fsum(data), rel=2e-16)

    def test_start_argument(self):
        # Test that the optional start argument works correctly.
        data = [random.uniform(1, 1000) for _ in range(100)]
        base_total = self.func(data)[1]
        self.assertEqual(base_total+42, self.func(data, 42)[1])
        self.assertEqual(base_total-23, self.func(data, -23)[1])
        self.assertEqual(base_total+Fraction(1e20), self.func(data, 1e20)[1])

    def test_strings_fail(self):
        # Sum of strings should fail, whether the string is the start value
        # or one of the data points.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], '999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Sum of bytes should fail, whether the bytes object is the start
        # value or one of the data points.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], b'999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed.
        # Check that mixed data types fail.
        with self.assertRaises(TypeError):
            self.func([1, 2.0, Decimal(1)])
        # And so does mixed start argument.
        with self.assertRaises(TypeError):
            self.func([1, 2.0], Decimal(1))
1235
1236
class SumTortureTest(NumericTestCase):
    # Sums whose huge terms cancel, leaving only the small terms.
    def test_torture(self):
        # Tim Peters' torture test for sum, and variants of same.
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            outcome = statistics._sum(pattern*10000)
            self.assertEqual(outcome, (float, Fraction(20000.0), 40000))
        kind, total, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(kind, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(total), 2.0e-96, rel=5e-16)
1248
1249
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN among the data makes the total a NAN of the same type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # A unittest assertion rather than a bare assert, so that the check
        # still runs under -O and reports both values on failure.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        # Helper shared by the float and Decimal infinity tests.
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        # Float infinities of either sign.
        inf = float('inf')
        # Named "direction" so as not to shadow the module-level sign().
        for direction in (+1, -1):
            self.do_test_inf(direction*inf)

    def test_decimal_inf(self):
        # Decimal infinities of either sign.
        inf = Decimal('inf')
        for direction in (+1, -1):
            self.do_test_inf(direction*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        # (Despite the "to_nan" in the name, under BasicContext the expected
        # outcome is the exception, not a NAN result.)
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1310
1311
1312# === Tests for averages ===
1313
class AverageMixin(UnivariateCommonMixin):
    # Mixin class holding common tests for averages.

    def test_single_value(self):
        # Average of a single value is the value itself.
        singles = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in singles:
            self.assertEqual(self.func([value]), value)

    def prepare_values_for_repeated_single_test(self):
        # Values for test_repeated_single_value; a separate method so that
        # subclasses can supply their own.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # The average of a single repeated value is the value itself.
        repeats = (2, 5, 10, 20)
        for x in self.prepare_values_for_repeated_single_test():
            for count in repeats:
                with self.subTest(x=x, count=count):
                    self.assertEqual(self.func([x]*count), x)
1332
1333
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mean, in addition to those from the mixins.
    def setUp(self):
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450.
        values = [1e100, 1, 3, -1e100]
        self.assertEqual(self.func(values), 1)

    def test_ints(self):
        # Test mean with ints.
        values = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(values)
        self.assertEqual(self.func(values), 4.8125)

    def test_floats(self):
        # Test mean with floats.
        values = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(values)
        self.assertEqual(self.func(values), 22.015625)

    def test_decimals(self):
        # Test mean with Decimals.
        texts = ("1.634", "2.517", "3.912", "4.072", "5.813")
        values = [Decimal(text) for text in texts]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal("3.5896"))

    def test_fractions(self):
        # Test mean with Fractions.
        pairs = ((1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8))
        values = [Fraction(num, den) for num, den in pairs]
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1479, 1960))

    def test_inf(self):
        # Test mean with infinities.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            for direction in (1, -1):
                infinity = kind("inf")*direction
                average = self.func(finite + [infinity])
                self.assertTrue(math.isinf(average))
                self.assertEqual(average, infinity)

    def test_mismatched_infs(self):
        # Test mean with infinities of opposite sign.
        values = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
        average = self.func(values)
        self.assertTrue(math.isnan(average))

    def test_nan(self):
        # Test mean with NANs.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            nan = kind("nan")
            average = self.func(finite + [nan])
            self.assertTrue(math.isnan(average))

    def test_big_data(self):
        # Test adding a large constant to every data point.
        offset = 1e9
        values = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(values) + offset
        assert expected != offset
        shifted = [x+offset for x in values]
        self.assertEqual(self.func(shifted), expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        values = [random.uniform(-3, 5) for _ in range(1000)]
        once = self.func(values)
        twice = self.func(values*2)
        self.assertApproxEqual(twice, once)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        value = Decimal('1e4')
        self.assertEqual(statistics.mean([value]), value)

    def test_regression_25177(self):
        # Regression test for issue 25177.
        # Ensure very big and very small floats don't overflow.
        # See http://bugs.python.org/issue25177.
        near_max = [8.988465674311579e+307, 8.98846567431158e+307]
        self.assertEqual(statistics.mean(near_max), 8.98846567431158e+307)
        big = 8.98846567431158e+307
        tiny = 5e-324
        for repeats in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*repeats), big)
            self.assertEqual(statistics.mean([tiny]*repeats), tiny)
1428
1429
class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.harmonic_mean, in addition to those from the
    # mixins.
    def setUp(self):
        self.func = statistics.harmonic_mean

    def prepare_data(self):
        # Override mixin method: same data, but with the zero taken out.
        data = super().prepare_data()
        data.remove(0)
        return data

    def prepare_values_for_repeated_single_test(self):
        # Override mixin method.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))

    def test_zero(self):
        # Test that harmonic mean returns zero when given zero.
        data = [1, 0, 2]
        self.assertEqual(self.func(data), 0)

    def test_negative_error(self):
        # Test that harmonic mean raises when given a negative value.
        expected_error = statistics.StatisticsError
        for values in ([-1], [1, -2, 3]):
            with self.subTest(values=values):
                with self.assertRaises(expected_error):
                    self.func(values)

    def test_ints(self):
        # Test harmonic mean with ints.
        values = [2, 4, 4, 8, 16, 16]
        random.shuffle(values)
        self.assertEqual(self.func(values), 6*4/5)

    def test_floats_exact(self):
        # Test harmonic mean with some carefully chosen floats.
        values = [1/8, 1/4, 1/4, 1/2, 1/2]
        random.shuffle(values)
        self.assertEqual(self.func(values), 1/4)
        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)

    def test_singleton_lists(self):
        # Test that harmonic mean([x]) returns x, compared with exact
        # equality.
        for value in range(1, 101):
            self.assertEqual(self.func([value]), value)

    def test_decimals_exact(self):
        # Test harmonic mean with some carefully chosen Decimals.
        whole = [Decimal(15), Decimal(30), Decimal(60), Decimal(60)]
        self.assertEqual(self.func(whole), Decimal(30))
        values = [Decimal(text) for text in ("0.05", "0.10", "0.20", "0.20")]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal("0.10"))
        values = [Decimal(text) for text in ("1.68", "0.32", "5.94", "2.75")]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal(66528)/70723)

    def test_fractions(self):
        # Test harmonic mean with Fractions.
        pairs = ((1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8))
        values = [Fraction(num, den) for num, den in pairs]
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(7*420, 4029))

    def test_inf(self):
        # Test harmonic mean with infinity.
        data = [2.0, float('inf'), 1.0]
        self.assertEqual(self.func(data), 2.0)

    def test_nan(self):
        # Test harmonic mean with NANs.
        data = [2.0, float('nan'), 1.0]
        self.assertTrue(math.isnan(self.func(data)))

    def test_multiply_data_points(self):
        # Test multiplying every data point by a constant.
        factor = 111
        values = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(values)*factor
        scaled = [x*factor for x in values]
        self.assertEqual(self.func(scaled), expected)

    def test_doubled_data(self):
        # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
        values = [random.uniform(1, 5) for _ in range(1000)]
        once = self.func(values)
        twice = self.func(values*2)
        self.assertApproxEqual(twice, once)
1516
1517
class TestMedian(NumericTestCase, AverageMixin):
    # Common tests for median and all median.* functions.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Overload method from UnivariateCommonMixin."""
        values = super().prepare_data()
        # Pad with one extra point if needed so the count is odd.
        if len(values) % 2 == 0:
            values.append(2)
        return values

    def test_even_ints(self):
        # Test median with an even number of int data points.
        values = [1, 2, 3, 4, 5, 6]
        assert len(values)%2 == 0
        self.assertEqual(self.func(values), 3.5)

    def test_odd_ints(self):
        # Test median with an odd number of int data points.
        values = [1, 2, 3, 4, 5, 6, 9]
        assert len(values)%2 == 1
        self.assertEqual(self.func(values), 4)

    def test_odd_fractions(self):
        # Test median works with an odd number of Fractions.
        values = [Fraction(n, 7) for n in (1, 2, 3, 4, 5)]
        assert len(values)%2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_fractions(self):
        # Test median works with an even number of Fractions.
        values = [Fraction(n, 7) for n in (1, 2, 3, 4, 5, 6)]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1, 2))

    def test_odd_decimals(self):
        # Test median works with an odd number of Decimals.
        texts = ('2.5', '3.1', '4.2', '5.7', '5.8')
        values = [Decimal(text) for text in texts]
        assert len(values)%2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.2'))

    def test_even_decimals(self):
        # Test median works with an even number of Decimals.
        texts = ('1.2', '2.5', '3.1', '4.2', '5.7', '5.8')
        values = [Decimal(text) for text in texts]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.65'))
1573
1574
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    # Checks that median hands back the same type as the data elements.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # Build an odd-length data set and keep shuffling until it is
        # guaranteed not to be in sorted order.
        values = list(range(15))
        assert len(values)%2 == 1
        while True:
            random.shuffle(values)
            if values != sorted(values):
                return values
1586
1587
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    # Tests for median_low.  Only the even-length cases are overridden
    # here; everything else is inherited from TestMedian.
    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        # median_low of an even number of ints picks the lower middle value.
        values = list(range(1, 7))
        assert len(values)%2 == 0
        self.assertEqual(self.func(values), 3)

    def test_even_fractions(self):
        # median_low of an even number of Fractions: 1/7 through 6/7.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_decimals(self):
        # median_low of an even number of Decimals.
        values = [Decimal(s) for s in '1.1 2.2 3.3 4.4 5.5 6.6'.split()]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.3'))
1613
1614
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    # Tests for median_high.  Only the even-length cases are overridden
    # here; everything else is inherited from TestMedian.
    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        # median_high of an even number of ints picks the upper middle value.
        values = list(range(1, 7))
        assert len(values)%2 == 0
        self.assertEqual(self.func(values), 4)

    def test_even_fractions(self):
        # median_high of an even number of Fractions: 1/7 through 6/7.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(4, 7))

    def test_even_decimals(self):
        # median_high of an even number of Decimals.
        values = [Decimal(s) for s in '1.1 2.2 3.3 4.4 5.5 6.6'.split()]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.4'))
1640
1641
class TestMedianGrouped(TestMedian):
    # Tests for median_grouped.
    # The function does not conserve the type of the data elements, so the
    # type-conservation mixin used by the other median tests is left out.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # median_grouped on odd-length data with repeated median values.
        # Each case is (data, extra positional arguments, expected result).
        exact_cases = [
            ([12, 13, 14, 14, 14, 15, 15], (), 14),
            ([12, 13, 14, 14, 14, 14, 15], (), 13.875),
            ([5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30], (5,), 19.375),
        ]
        for values, extra, expected in exact_cases:
            assert len(values)%2 == 1
            self.assertEqual(self.func(values, *extra), expected)
        # The last case is only checked to within a tolerance.
        values = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(values)%2 == 1
        self.assertApproxEqual(self.func(values, 2), 20.66666667, tol=1e-8)

    def test_even_number_repeated(self):
        # median_grouped on even-length data with repeated median values.
        # Each case is (data, extra positional arguments, expected result).
        approx_cases = [
            ([5, 10, 10, 15, 20, 20, 20, 25, 25, 30], (5,), 19.16666667),
            ([2, 3, 4, 4, 4, 5], (), 3.83333333),
        ]
        for values, extra, expected in approx_cases:
            assert len(values)%2 == 0
            result = self.func(values, *extra)
            self.assertApproxEqual(result, expected, tol=1e-8)
        exact_cases = [
            ([2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6], 4.5),
            ([3, 4, 4, 4, 5, 5, 5, 5, 6, 6], 4.75),
        ]
        for values, expected in exact_cases:
            assert len(values)%2 == 0
            self.assertEqual(self.func(values), expected)

    def test_repeated_single_value(self):
        # Override method from AverageMixin.
        # Because median_grouped does not conserve the data type, the
        # expected result is the float equivalent of the repeated value.
        samples = (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714'))
        for value in samples:
            expected = float(value)
            for repeats in (2, 5, 10, 20):
                self.assertEqual(self.func([value]*repeats), expected)

    def test_odd_fractions(self):
        # median_grouped of an odd number of Fractions.
        values = [Fraction(n, 4) for n in (5, 9, 13, 13, 17)]
        assert len(values)%2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), 3.0)

    def test_even_fractions(self):
        # median_grouped of an even number of Fractions.
        values = [Fraction(n, 4) for n in (5, 9, 13, 13, 17, 17)]
        assert len(values)%2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), 3.25)

    def test_odd_decimals(self):
        # median_grouped of an odd number of Decimals.
        values = [Decimal(s) for s in '5.5 6.5 6.5 7.5 8.5'.split()]
        assert len(values)%2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), 6.75)

    def test_even_decimals(self):
        # median_grouped of an even number of Decimals.
        # Each case is (space-separated Decimal literals, expected result).
        cases = [
            ('5.5 5.5 6.5 6.5 7.5 8.5', 6.5),
            ('5.5 5.5 6.5 7.5 7.5 8.5', 7.0),
        ]
        for literals, expected in cases:
            values = [Decimal(s) for s in literals.split()]
            assert len(values)%2 == 0
            random.shuffle(values)
            self.assertEqual(self.func(values), expected)

    def test_interval(self):
        # median_grouped with an explicit interval argument.
        quarters = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(self.func(quarters, 0.25), 2.875)
        quarters = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertApproxEqual(self.func(quarters, 0.25), 2.83333333, tol=1e-8)
        twenties = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
        self.assertEqual(self.func(twenties, 20), 265.0)

    def test_data_type_error(self):
        # str and bytes are rejected with TypeError, both as data points
        # and as the interval argument.
        for bad_point in ("", b""):
            with self.assertRaises(TypeError):
                self.func([bad_point]*3)
        for bad_interval in ("", b""):
            with self.assertRaises(TypeError):
                self.func([1, 2, 3], bad_interval)
1754
1755
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for mode, the discrete (most common value) version.
    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Override method from UnivariateCommonMixin."""
        # The value 1 occurs four times and everything else once, so the
        # data has exactly one mode.
        return [1]*4 + [3, 4, 7, 9, 0, 8, 2]

    def test_range_data(self):
        # Override test from UnivariateCommonMixin: every value in the
        # range is unique, so an exception is expected instead of a result.
        with self.assertRaises(statistics.StatisticsError):
            self.func(range(20, 50, 3))

    def test_nominal_data(self):
        # mode works on nominal (non-numeric) data too.
        self.assertEqual(self.func('abcbdb'), 'b')
        words = ['fe', 'fi', 'fo', 'fum', 'fi', 'fi']
        self.assertEqual(self.func(words), 'fi')

    def test_discrete_data(self):
        # Duplicate each of 0..9 in turn; the duplicated value is the mode.
        base = list(range(10))
        for repeated in range(10):
            sample = base + [repeated]
            random.shuffle(sample)
            self.assertEqual(self.func(sample), repeated)

    def test_bimodal_data(self):
        # Two values tie for most common, which must raise.
        values = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert values.count(2) == values.count(6) == 4
        with self.assertRaises(statistics.StatisticsError):
            self.func(values)

    def test_unique_data_failure(self):
        # When every data point is unique there is no mode, which must raise.
        with self.assertRaises(statistics.StatisticsError):
            self.func(list(range(10)))

    def test_none_data(self):
        # mode(None) must raise TypeError.

        # This needs an explicit test because the implementation of mode
        # uses collections.Counter, which accepts None and returns an
        # empty dict.
        with self.assertRaises(TypeError):
            self.func(None)

    def test_counter_data(self):
        # A Counter is treated like any other iterable: its keys are the
        # data points and its counts are ignored, so this should raise.
        tally = collections.Counter([1, 1, 1, 2])
        with self.assertRaises(statistics.StatisticsError):
            self.func(tally)
1811
1812
1813
1814# === Tests for variances and standard deviations ===
1815
class VarianceStdevMixin(UnivariateCommonMixin):
    # Tests shared by the variance and standard deviation test classes.

    # Subclasses should list this mixin ahead of NumericTestCase in their
    # bases, in order to see the rel attribute below. See test_shift_data
    # for an explanation.

    rel = 1e-12

    def test_single_value(self):
        # A single data point has zero deviation.
        singles = (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392'))
        for value in singles:
            self.assertEqual(self.func([value]), 0)

    def test_repeated_single_value(self):
        # One value repeated any number of times has zero deviation.
        repeated = (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802'))
        for value in repeated:
            for repeats in (2, 3, 5, 15):
                self.assertEqual(self.func([value]*repeats), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        identical = [0.123456789012345]*10000
        # Every item is the same, so the exact answer is zero.  A tiny
        # amount of round-off error is tolerated, a negative result is not.
        outcome = self.func(identical)
        self.assertApproxEqual(outcome, 0.0, tol=5e-17)
        self.assertGreaterEqual(outcome, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Adding a constant to every data point should leave the variance
        # or stdev unchanged, or very nearly so.

        # Because of rounding this is an ideal rather than a guarantee, so
        # the comparison allows whatever error the tol and/or rel
        # attributes permit. Subclasses may tighten or loosen them.
        original = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        unshifted = self.func(original)
        # Keep the offset modest: the larger it is, the more rounding error.
        offset = 1e5
        shifted = self.func([x + offset for x in original])
        self.assertApproxEqual(shifted, unshifted)

    def test_shift_data_exact(self):
        # Same idea as test_shift_data, but with whole-number data so the
        # result is always exact.
        original = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(x==int(x) for x in original)
        unshifted = self.func(original)
        offset = 10**9
        shifted = self.func([x + offset for x in original])
        self.assertEqual(shifted, unshifted)

    def test_iter_list_same(self):
        # An iterator and a list over the same data must give equal results.

        # This duplicates a similar test in UnivariateCommonMixin on
        # purpose: an earlier design had variance and friends switch
        # between one- and two-pass algorithms, which would sometimes give
        # different results.
        sample = [random.uniform(-3, 8) for _ in range(1000)]
        from_list = self.func(sample)
        from_iterator = self.func(iter(sample))
        self.assertEqual(from_iterator, from_list)
1880
1881
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for pvariance, the population variance.
    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # Population variance of the ints 0..n-1 has the closed form
        # (n**2 - 1)/12, which gives an exact value to compare against.
        n = 10000
        values = list(range(n))
        random.shuffle(values)
        self.assertEqual(self.func(values), (n**2 - 1)/12)

    def test_ints(self):
        # Population variance of int data.
        self.assertEqual(self.func([4, 7, 13, 16]), 22.5)

    def test_fractions(self):
        # Population variance of Fraction data stays an exact Fraction.
        values = [Fraction(n, 4) for n in (1, 1, 3, 7)]
        outcome = self.func(values)
        self.assertEqual(outcome, Fraction(3, 8))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Population variance of Decimal data stays a Decimal.
        values = [Decimal(s) for s in ("12.1", "12.2", "12.5", "12.9")]
        outcome = self.func(values)
        self.assertEqual(outcome, Decimal('0.096875'))
        self.assertIsInstance(outcome, Decimal)
1917
1918
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for variance, the sample variance.
    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Override method from VarianceStdevMixin: a single data point
        # must raise StatisticsError rather than return zero.
        singles = (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084'))
        for value in singles:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_ints(self):
        # Sample variance of int data.
        self.assertEqual(self.func([4, 7, 13, 16]), 30)

    def test_fractions(self):
        # Sample variance of Fraction data stays an exact Fraction.
        values = [Fraction(n, 4) for n in (1, 1, 3, 7)]
        outcome = self.func(values)
        self.assertEqual(outcome, Fraction(1, 2))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Sample variance of Decimal data stays a Decimal.
        values = [Decimal(n) for n in (2, 2, 7, 9)]
        wanted = 4*Decimal('9.5')/Decimal(3)
        outcome = self.func(values)
        self.assertEqual(outcome, wanted)
        self.assertIsInstance(outcome, Decimal)
1952
1953
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for pstdev, the population standard deviation.
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # pstdev must equal the square root of pvariance on the same data.
        sample = [random.uniform(-17, 24) for _ in range(1000)]
        variance = statistics.pvariance(sample)
        self.assertEqual(self.func(sample), math.sqrt(variance))
1964
1965
class TestStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for stdev, the sample standard deviation.
    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Override method from VarianceStdevMixin: a single data point
        # must raise StatisticsError rather than return zero.
        singles = (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719'))
        for value in singles:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_compare_to_variance(self):
        # stdev must equal the square root of variance on the same data.
        sample = [random.uniform(-2, 9) for _ in range(1000)]
        variance = statistics.variance(sample)
        self.assertEqual(self.func(sample), math.sqrt(variance))
1981
1982
1983# === Run tests ===
1984
def load_tests(loader, tests, ignore):
    """Add the doctests to the unittest suite and return the suite.

    The loader and ignore arguments are accepted but not used.
    """
    # DocTestSuite() must be called directly from this function: with no
    # module argument it collects doctests from the module that calls it.
    doctest_suite = doctest.DocTestSuite()
    tests.addTests(doctest_suite)
    return tests
1989
1990
# Run the whole test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
1993