1# Python test set -- part 6, built-in types
2
3from test.support import run_with_locale
4import collections.abc
5import inspect
6import pickle
7import locale
8import sys
9import types
10import unittest.mock
11import weakref
12
class TypesTests(unittest.TestCase):
    """Smoke tests for the built-in scalar types.

    Covers truthiness, boolean operators, comparisons, int/float/str
    arithmetic corner cases, and the __format__ mini-language for int and
    float (including the locale-aware 'n' presentation type).
    """

    def test_truth_values(self):
        if None: self.fail('None is true instead of false')
        if 0: self.fail('0 is true instead of false')
        if 0.0: self.fail('0.0 is true instead of false')
        if '': self.fail('\'\' is true instead of false')
        if not 1: self.fail('1 is false instead of true')
        if not 1.0: self.fail('1.0 is false instead of true')
        if not 'x': self.fail('\'x\' is false instead of true')
        if not {'x': 1}: self.fail('{\'x\': 1} is false instead of true')
        def f(): pass
        class C: pass
        x = C()
        if not f: self.fail('f is false instead of true')
        if not C: self.fail('C is false instead of true')
        if not sys: self.fail('sys is false instead of true')
        if not x: self.fail('x is false instead of true')

    def test_boolean_ops(self):
        if 0 or 0: self.fail('0 or 0 is true instead of false')
        if 1 and 1: pass
        else: self.fail('1 and 1 is false instead of true')
        if not 1: self.fail('not 1 is true instead of false')

    def test_comparisons(self):
        if 0 < 1 <= 1 == 1 >= 1 > 0 != 1: pass
        else: self.fail('int comparisons failed')
        if 0.0 < 1.0 <= 1.0 == 1.0 >= 1.0 > 0.0 != 1.0: pass
        else: self.fail('float comparisons failed')
        if '' < 'a' <= 'a' == 'a' < 'abc' < 'abd' < 'b': pass
        else: self.fail('string comparisons failed')
        if None is None: pass
        else: self.fail('identity test failed')

    def test_float_constructor(self):
        self.assertRaises(ValueError, float, '')
        self.assertRaises(ValueError, float, '5\0')
        self.assertRaises(ValueError, float, '5_5\0')

    def test_zero_division(self):
        try: 5.0 / 0.0
        except ZeroDivisionError: pass
        else: self.fail("5.0 / 0.0 didn't raise ZeroDivisionError")

        try: 5.0 // 0.0
        except ZeroDivisionError: pass
        else: self.fail("5.0 // 0.0 didn't raise ZeroDivisionError")

        try: 5.0 % 0.0
        except ZeroDivisionError: pass
        else: self.fail("5.0 % 0.0 didn't raise ZeroDivisionError")

        try: 5 / 0
        except ZeroDivisionError: pass
        else: self.fail("5 / 0 didn't raise ZeroDivisionError")

        try: 5 // 0
        except ZeroDivisionError: pass
        else: self.fail("5 // 0 didn't raise ZeroDivisionError")

        try: 5 % 0
        except ZeroDivisionError: pass
        else: self.fail("5 % 0 didn't raise ZeroDivisionError")

    def test_numeric_types(self):
        if 0 != 0.0 or 1 != 1.0 or -1 != -1.0:
            self.fail('int/float value not equal')
        # calling built-in types without argument must return 0
        if int() != 0: self.fail('int() does not return 0')
        if float() != 0.0: self.fail('float() does not return 0.0')
        if int(1.9) == 1 == int(1.1) and int(-1.1) == -1 == int(-1.9): pass
        else: self.fail('int() does not round properly')
        if float(1) == 1.0 and float(-1) == -1.0 and float(0) == 0.0: pass
        else: self.fail('float() does not work properly')

    def test_float_to_string(self):
        def test(f, result):
            self.assertEqual(f.__format__('e'), result)
            self.assertEqual('%e' % f, result)

        # test all 2 digit exponents, both with __format__ and with
        #  '%' formatting
        for i in range(-99, 100):
            test(float('1.5e'+str(i)), '1.500000e{0:+03d}'.format(i))

        # test some 3 digit exponents
        self.assertEqual(1.5e100.__format__('e'), '1.500000e+100')
        self.assertEqual('%e' % 1.5e100, '1.500000e+100')

        self.assertEqual(1.5e101.__format__('e'), '1.500000e+101')
        self.assertEqual('%e' % 1.5e101, '1.500000e+101')

        self.assertEqual(1.5e-100.__format__('e'), '1.500000e-100')
        self.assertEqual('%e' % 1.5e-100, '1.500000e-100')

        self.assertEqual(1.5e-101.__format__('e'), '1.500000e-101')
        self.assertEqual('%e' % 1.5e-101, '1.500000e-101')

        self.assertEqual('%g' % 1.0, '1')
        self.assertEqual('%#g' % 1.0, '1.00000')

    def test_normal_integers(self):
        # Ensure the first 256 integers are shared
        a = 256
        b = 128*2
        if a is not b: self.fail('256 is not shared')
        if 12 + 24 != 36: self.fail('int op')
        if 12 + (-24) != -12: self.fail('int op')
        if (-12) + 24 != 12: self.fail('int op')
        if (-12) + (-24) != -36: self.fail('int op')
        if not 12 < 24: self.fail('int op')
        if not -24 < -12: self.fail('int op')
        # Test for a particular bug in integer multiply
        xsize, ysize, zsize = 238, 356, 4
        if not (xsize*ysize*zsize == zsize*xsize*ysize == 338912):
            self.fail('int mul commutativity')
        # And another.
        m = -sys.maxsize - 1
        for divisor in 1, 2, 4, 8, 16, 32:
            j = m // divisor
            prod = divisor * j
            if prod != m:
                self.fail("%r * %r == %r != %r" % (divisor, j, prod, m))
            if type(prod) is not int:
                self.fail("expected type(prod) to be int, not %r" %
                          type(prod))
        # Check for unified integral type
        for divisor in 1, 2, 4, 8, 16, 32:
            j = m // divisor - 1
            prod = divisor * j
            if type(prod) is not int:
                self.fail("expected type(%r) to be int, not %r" %
                          (prod, type(prod)))
        # Check for unified integral type
        m = sys.maxsize
        for divisor in 1, 2, 4, 8, 16, 32:
            j = m // divisor + 1
            prod = divisor * j
            if type(prod) is not int:
                self.fail("expected type(%r) to be int, not %r" %
                          (prod, type(prod)))

        x = sys.maxsize
        self.assertIsInstance(x + 1, int,
                              "(sys.maxsize + 1) should have returned int")
        self.assertIsInstance(-x - 1, int,
                              "(-sys.maxsize - 1) should have returned int")
        self.assertIsInstance(-x - 2, int,
                              "(-sys.maxsize - 2) should have returned int")

        try: 5 << -5
        except ValueError: pass
        else: self.fail('int negative shift <<')

        try: 5 >> -5
        except ValueError: pass
        else: self.fail('int negative shift >>')

    def test_floats(self):
        if 12.0 + 24.0 != 36.0: self.fail('float op')
        if 12.0 + (-24.0) != -12.0: self.fail('float op')
        if (-12.0) + 24.0 != 12.0: self.fail('float op')
        if (-12.0) + (-24.0) != -36.0: self.fail('float op')
        if not 12.0 < 24.0: self.fail('float op')
        if not -24.0 < -12.0: self.fail('float op')

    def test_strings(self):
        if len('') != 0: self.fail('len(\'\')')
        if len('a') != 1: self.fail('len(\'a\')')
        if len('abcdef') != 6: self.fail('len(\'abcdef\')')
        if 'xyz' + 'abcde' != 'xyzabcde': self.fail('string concatenation')
        if 'xyz'*3 != 'xyzxyzxyz': self.fail('string repetition *3')
        if 0*'abcde' != '': self.fail('string repetition 0*')
        if min('abc') != 'a' or max('abc') != 'c': self.fail('min/max string')
        if 'a' in 'abc' and 'b' in 'abc' and 'c' in 'abc' and 'd' not in 'abc': pass
        else: self.fail('in/not in string')
        x = 'x'*103
        if '%s!'%x != x+'!': self.fail('nasty string formatting bug')

        #extended slices for strings
        a = '0123456789'
        self.assertEqual(a[::], a)
        self.assertEqual(a[::2], '02468')
        self.assertEqual(a[1::2], '13579')
        self.assertEqual(a[::-1],'9876543210')
        self.assertEqual(a[::-2], '97531')
        self.assertEqual(a[3::-2], '31')
        self.assertEqual(a[-100:100:], a)
        self.assertEqual(a[100:-100:-1], a[::-1])
        self.assertEqual(a[-100:100:2], '02468')

    def test_type_function(self):
        self.assertRaises(TypeError, type, 1, 2)
        self.assertRaises(TypeError, type, 1, 2, 3, 4)

    def test_int__format__(self):
        def test(i, format_spec, result):
            # just make sure we have the unified type for integers; use
            # assertIs rather than a bare assert so the check is not
            # stripped when running under -O
            self.assertIs(type(i), int)
            self.assertIs(type(format_spec), str)
            self.assertEqual(i.__format__(format_spec), result)

        test(123456789, 'd', '123456789')

        test(1, 'c', '\01')

        # sign and aligning are interdependent
        test(1, "-", '1')
        test(-1, "-", '-1')
        test(1, "-3", '  1')
        test(-1, "-3", ' -1')
        test(1, "+3", ' +1')
        test(-1, "+3", ' -1')
        test(1, " 3", '  1')
        test(-1, " 3", ' -1')
        test(1, " ", ' 1')
        test(-1, " ", '-1')

        # hex
        test(3, "x", "3")
        test(3, "X", "3")
        test(1234, "x", "4d2")
        test(-1234, "x", "-4d2")
        test(1234, "8x", "     4d2")
        test(-1234, "8x", "    -4d2")
        test(-3, "x", "-3")
        test(-3, "X", "-3")
        test(int('be', 16), "x", "be")
        test(int('be', 16), "X", "BE")
        test(-int('be', 16), "x", "-be")
        test(-int('be', 16), "X", "-BE")

        # octal
        test(3, "o", "3")
        test(-3, "o", "-3")
        test(65, "o", "101")
        test(-65, "o", "-101")
        test(1234, "o", "2322")
        test(-1234, "o", "-2322")
        test(1234, "-o", "2322")
        test(-1234, "-o", "-2322")
        test(1234, " o", " 2322")
        test(-1234, " o", "-2322")
        test(1234, "+o", "+2322")
        test(-1234, "+o", "-2322")

        # binary
        test(3, "b", "11")
        test(-3, "b", "-11")
        test(1234, "b", "10011010010")
        test(-1234, "b", "-10011010010")
        test(1234, "-b", "10011010010")
        test(-1234, "-b", "-10011010010")
        test(1234, " b", " 10011010010")
        test(-1234, " b", "-10011010010")
        test(1234, "+b", "+10011010010")
        test(-1234, "+b", "-10011010010")

        # alternate (#) formatting
        test(0, "#b", '0b0')
        test(0, "-#b", '0b0')
        test(1, "-#b", '0b1')
        test(-1, "-#b", '-0b1')
        test(-1, "-#5b", ' -0b1')
        test(1, "+#5b", ' +0b1')
        test(100, "+#b", '+0b1100100')
        test(100, "#012b", '0b0001100100')
        test(-100, "#012b", '-0b001100100')

        test(0, "#o", '0o0')
        test(0, "-#o", '0o0')
        test(1, "-#o", '0o1')
        test(-1, "-#o", '-0o1')
        test(-1, "-#5o", ' -0o1')
        test(1, "+#5o", ' +0o1')
        test(100, "+#o", '+0o144')
        test(100, "#012o", '0o0000000144')
        test(-100, "#012o", '-0o000000144')

        test(0, "#x", '0x0')
        test(0, "-#x", '0x0')
        test(1, "-#x", '0x1')
        test(-1, "-#x", '-0x1')
        test(-1, "-#5x", ' -0x1')
        test(1, "+#5x", ' +0x1')
        test(100, "+#x", '+0x64')
        test(100, "#012x", '0x0000000064')
        test(-100, "#012x", '-0x000000064')
        test(123456, "#012x", '0x000001e240')
        test(-123456, "#012x", '-0x00001e240')

        test(0, "#X", '0X0')
        test(0, "-#X", '0X0')
        test(1, "-#X", '0X1')
        test(-1, "-#X", '-0X1')
        test(-1, "-#5X", ' -0X1')
        test(1, "+#5X", ' +0X1')
        test(100, "+#X", '+0X64')
        test(100, "#012X", '0X0000000064')
        test(-100, "#012X", '-0X000000064')
        test(123456, "#012X", '0X000001E240')
        test(-123456, "#012X", '-0X00001E240')

        test(123, ',', '123')
        test(-123, ',', '-123')
        test(1234, ',', '1,234')
        test(-1234, ',', '-1,234')
        test(123456, ',', '123,456')
        test(-123456, ',', '-123,456')
        test(1234567, ',', '1,234,567')
        test(-1234567, ',', '-1,234,567')

        # issue 5782, commas with no specifier type
        test(1234, '010,', '00,001,234')

        # Unified type for integers
        test(10**100, 'd', '1' + '0' * 100)
        test(10**100+100, 'd', '1' + '0' * 97 + '100')

        # make sure these are errors

        # precision disallowed
        self.assertRaises(ValueError, 3 .__format__, "1.3")
        # sign not allowed with 'c'
        self.assertRaises(ValueError, 3 .__format__, "+c")
        # format spec must be string
        self.assertRaises(TypeError, 3 .__format__, None)
        self.assertRaises(TypeError, 3 .__format__, 0)
        # can't have ',' with 'n'
        self.assertRaises(ValueError, 3 .__format__, ",n")
        # can't have ',' with 'c'
        self.assertRaises(ValueError, 3 .__format__, ",c")
        # can't have '#' with 'c'
        self.assertRaises(ValueError, 3 .__format__, "#c")

        # ensure that only int and float type specifiers work
        for format_spec in ([chr(x) for x in range(ord('a'), ord('z')+1)] +
                            [chr(x) for x in range(ord('A'), ord('Z')+1)]):
            if format_spec not in 'bcdoxXeEfFgGn%':
                self.assertRaises(ValueError, 0 .__format__, format_spec)
                self.assertRaises(ValueError, 1 .__format__, format_spec)
                self.assertRaises(ValueError, (-1) .__format__, format_spec)

        # ensure that float type specifiers work; format converts
        #  the int to a float
        for format_spec in 'eEfFgG%':
            for value in [0, 1, -1, 100, -100, 1234567890, -1234567890]:
                self.assertEqual(value.__format__(format_spec),
                                 float(value).__format__(format_spec))

        # Issue 6902
        test(123456, "0<20", '12345600000000000000')
        test(123456, "1<20", '12345611111111111111')
        test(123456, "*<20", '123456**************')
        test(123456, "0>20", '00000000000000123456')
        test(123456, "1>20", '11111111111111123456')
        test(123456, "*>20", '**************123456')
        test(123456, "0=20", '00000000000000123456')
        test(123456, "1=20", '11111111111111123456')
        test(123456, "*=20", '**************123456')

    @run_with_locale('LC_NUMERIC', 'en_US.UTF8')
    def test_float__format__locale(self):
        # test locale support for __format__ code 'n'
        # locale.format() was deprecated in 3.7 and removed in 3.12;
        # locale.format_string() is the supported equivalent for a single
        # '%' specifier.
        for i in range(-10, 10):
            x = 1234567890.0 * (10.0 ** i)
            self.assertEqual(locale.format_string('%g', x, grouping=True), format(x, 'n'))
            self.assertEqual(locale.format_string('%.10g', x, grouping=True), format(x, '.10n'))

    @run_with_locale('LC_NUMERIC', 'en_US.UTF8')
    def test_int__format__locale(self):
        # test locale support for __format__ code 'n' for integers
        # (locale.format_string replaces the removed locale.format)

        x = 123456789012345678901234567890
        for i in range(0, 30):
            self.assertEqual(locale.format_string('%d', x, grouping=True), format(x, 'n'))

            # move to the next integer to test
            x = x // 10

        # Padding must count the locale's grouping separators, so every
        # width-20 result has the same length regardless of digit count.
        rfmt = ">20n"
        lfmt = "<20n"
        cfmt = "^20n"
        for x in (1234, 12345, 123456, 1234567, 12345678, 123456789, 1234567890, 12345678900):
            self.assertEqual(len(format(0, rfmt)), len(format(x, rfmt)))
            self.assertEqual(len(format(0, lfmt)), len(format(x, lfmt)))
            self.assertEqual(len(format(0, cfmt)), len(format(x, cfmt)))

    def test_float__format__(self):
        def test(f, format_spec, result):
            self.assertEqual(f.__format__(format_spec), result)
            self.assertEqual(format(f, format_spec), result)

        test(0.0, 'f', '0.000000')

        # the default is 'g', except for empty format spec
        test(0.0, '', '0.0')
        test(0.01, '', '0.01')
        test(0.01, 'g', '0.01')

        # test for issue 3411
        test(1.23, '1', '1.23')
        test(-1.23, '1', '-1.23')
        test(1.23, '1g', '1.23')
        test(-1.23, '1g', '-1.23')

        test( 1.0, ' g', ' 1')
        test(-1.0, ' g', '-1')
        test( 1.0, '+g', '+1')
        test(-1.0, '+g', '-1')
        test(1.1234e200, 'g', '1.1234e+200')
        test(1.1234e200, 'G', '1.1234E+200')


        test(1.0, 'f', '1.000000')

        test(-1.0, 'f', '-1.000000')

        test( 1.0, ' f', ' 1.000000')
        test(-1.0, ' f', '-1.000000')
        test( 1.0, '+f', '+1.000000')
        test(-1.0, '+f', '-1.000000')

        # Python versions <= 3.0 switched from 'f' to 'g' formatting for
        # values larger than 1e50.  No longer.
        f = 1.1234e90
        for fmt in 'f', 'F':
            # don't do a direct equality check, since on some
            # platforms only the first few digits of dtoa
            # will be reliable
            result = f.__format__(fmt)
            self.assertEqual(len(result), 98)
            self.assertEqual(result[-7], '.')
            self.assertIn(result[:12], ('112340000000', '112339999999'))
        f = 1.1234e200
        for fmt in 'f', 'F':
            result = f.__format__(fmt)
            self.assertEqual(len(result), 208)
            self.assertEqual(result[-7], '.')
            self.assertIn(result[:12], ('112340000000', '112339999999'))


        test( 1.0, 'e', '1.000000e+00')
        test(-1.0, 'e', '-1.000000e+00')
        test( 1.0, 'E', '1.000000E+00')
        test(-1.0, 'E', '-1.000000E+00')
        test(1.1234e20, 'e', '1.123400e+20')
        test(1.1234e20, 'E', '1.123400E+20')

        # No format code means use g, but must have a decimal
        # and a number after the decimal.  This is tricky, because
        # a totally empty format specifier means something else.
        # So, just use a sign flag
        test(1e200, '+g', '+1e+200')
        test(1e200, '+', '+1e+200')

        test(1.1e200, '+g', '+1.1e+200')
        test(1.1e200, '+', '+1.1e+200')

        # 0 padding
        test(1234., '010f', '1234.000000')
        test(1234., '011f', '1234.000000')
        test(1234., '012f', '01234.000000')
        test(-1234., '011f', '-1234.000000')
        test(-1234., '012f', '-1234.000000')
        test(-1234., '013f', '-01234.000000')
        test(-1234.12341234, '013f', '-01234.123412')
        test(-123456.12341234, '011.2f', '-0123456.12')

        # issue 5782, commas with no specifier type
        test(1.2, '010,.2', '0,000,001.2')

        # 0 padding with commas
        test(1234., '011,f', '1,234.000000')
        test(1234., '012,f', '1,234.000000')
        test(1234., '013,f', '01,234.000000')
        test(-1234., '012,f', '-1,234.000000')
        test(-1234., '013,f', '-1,234.000000')
        test(-1234., '014,f', '-01,234.000000')
        test(-12345., '015,f', '-012,345.000000')
        test(-123456., '016,f', '-0,123,456.000000')
        test(-123456., '017,f', '-0,123,456.000000')
        test(-123456.12341234, '017,f', '-0,123,456.123412')
        test(-123456.12341234, '013,.2f', '-0,123,456.12')

        # % formatting
        test(-1.0, '%', '-100.000000%')

        # format spec must be string
        self.assertRaises(TypeError, 3.0.__format__, None)
        self.assertRaises(TypeError, 3.0.__format__, 0)

        # other format specifiers shouldn't work on floats,
        #  in particular int specifiers
        for format_spec in ([chr(x) for x in range(ord('a'), ord('z')+1)] +
                            [chr(x) for x in range(ord('A'), ord('Z')+1)]):
            if format_spec not in 'eEfFgGn%':
                self.assertRaises(ValueError, format, 0.0, format_spec)
                self.assertRaises(ValueError, format, 1.0, format_spec)
                self.assertRaises(ValueError, format, -1.0, format_spec)
                self.assertRaises(ValueError, format, 1e100, format_spec)
                self.assertRaises(ValueError, format, -1e100, format_spec)
                self.assertRaises(ValueError, format, 1e-100, format_spec)
                self.assertRaises(ValueError, format, -1e-100, format_spec)

        # Alternate float formatting
        test(1.0, '.0e', '1e+00')
        test(1.0, '#.0e', '1.e+00')
        test(1.0, '.0f', '1')
        test(1.0, '#.0f', '1.')
        test(1.1, 'g', '1.1')
        test(1.1, '#g', '1.10000')
        test(1.0, '.0%', '100%')
        test(1.0, '#.0%', '100.%')

        # Issue 7094: Alternate formatting (specified by #)
        test(1.0, '0e',  '1.000000e+00')
        test(1.0, '#0e', '1.000000e+00')
        test(1.0, '0f',  '1.000000' )
        test(1.0, '#0f', '1.000000')
        test(1.0, '.1e',  '1.0e+00')
        test(1.0, '#.1e', '1.0e+00')
        test(1.0, '.1f',  '1.0')
        test(1.0, '#.1f', '1.0')
        test(1.0, '.1%',  '100.0%')
        test(1.0, '#.1%', '100.0%')

        # Issue 6902
        test(12345.6, "0<20", '12345.60000000000000')
        test(12345.6, "1<20", '12345.61111111111111')
        test(12345.6, "*<20", '12345.6*************')
        test(12345.6, "0>20", '000000000000012345.6')
        test(12345.6, "1>20", '111111111111112345.6')
        test(12345.6, "*>20", '*************12345.6')
        test(12345.6, "0=20", '000000000000012345.6')
        test(12345.6, "1=20", '111111111111112345.6')
        test(12345.6, "*=20", '*************12345.6')

    def test_format_spec_errors(self):
        # int, float, and string all share the same format spec
        # mini-language parser.

        # Check that we can't ask for too many digits. This is
        # probably a CPython specific test. It tries to put the width
        # into a C long.
        self.assertRaises(ValueError, format, 0, '1'*10000 + 'd')

        # Similar with the precision.
        self.assertRaises(ValueError, format, 0, '.' + '1'*10000 + 'd')

        # And may as well test both.
        self.assertRaises(ValueError, format, 0, '1'*1000 + '.' + '1'*10000 + 'd')

        # Make sure commas aren't allowed with various type codes
        for code in 'xXobns':
            self.assertRaises(ValueError, format, 0, ',' + code)

    def test_internal_sizes(self):
        self.assertGreater(object.__basicsize__, 0)
        self.assertGreater(tuple.__itemsize__, 0)
578
579
580class MappingProxyTests(unittest.TestCase):
581    mappingproxy = types.MappingProxyType
582
583    def test_constructor(self):
584        class userdict(dict):
585            pass
586
587        mapping = {'x': 1, 'y': 2}
588        self.assertEqual(self.mappingproxy(mapping), mapping)
589        mapping = userdict(x=1, y=2)
590        self.assertEqual(self.mappingproxy(mapping), mapping)
591        mapping = collections.ChainMap({'x': 1}, {'y': 2})
592        self.assertEqual(self.mappingproxy(mapping), mapping)
593
594        self.assertRaises(TypeError, self.mappingproxy, 10)
595        self.assertRaises(TypeError, self.mappingproxy, ("a", "tuple"))
596        self.assertRaises(TypeError, self.mappingproxy, ["a", "list"])
597
598    def test_methods(self):
599        attrs = set(dir(self.mappingproxy({}))) - set(dir(object()))
600        self.assertEqual(attrs, {
601             '__contains__',
602             '__getitem__',
603             '__iter__',
604             '__len__',
605             'copy',
606             'get',
607             'items',
608             'keys',
609             'values',
610        })
611
612    def test_get(self):
613        view = self.mappingproxy({'a': 'A', 'b': 'B'})
614        self.assertEqual(view['a'], 'A')
615        self.assertEqual(view['b'], 'B')
616        self.assertRaises(KeyError, view.__getitem__, 'xxx')
617        self.assertEqual(view.get('a'), 'A')
618        self.assertIsNone(view.get('xxx'))
619        self.assertEqual(view.get('xxx', 42), 42)
620
621    def test_missing(self):
622        class dictmissing(dict):
623            def __missing__(self, key):
624                return "missing=%s" % key
625
626        view = self.mappingproxy(dictmissing(x=1))
627        self.assertEqual(view['x'], 1)
628        self.assertEqual(view['y'], 'missing=y')
629        self.assertEqual(view.get('x'), 1)
630        self.assertEqual(view.get('y'), None)
631        self.assertEqual(view.get('y', 42), 42)
632        self.assertTrue('x' in view)
633        self.assertFalse('y' in view)
634
635    def test_customdict(self):
636        class customdict(dict):
637            def __contains__(self, key):
638                if key == 'magic':
639                    return True
640                else:
641                    return dict.__contains__(self, key)
642
643            def __iter__(self):
644                return iter(('iter',))
645
646            def __len__(self):
647                return 500
648
649            def copy(self):
650                return 'copy'
651
652            def keys(self):
653                return 'keys'
654
655            def items(self):
656                return 'items'
657
658            def values(self):
659                return 'values'
660
661            def __getitem__(self, key):
662                return "getitem=%s" % dict.__getitem__(self, key)
663
664            def get(self, key, default=None):
665                return "get=%s" % dict.get(self, key, 'default=%r' % default)
666
667        custom = customdict({'key': 'value'})
668        view = self.mappingproxy(custom)
669        self.assertTrue('key' in view)
670        self.assertTrue('magic' in view)
671        self.assertFalse('xxx' in view)
672        self.assertEqual(view['key'], 'getitem=value')
673        self.assertRaises(KeyError, view.__getitem__, 'xxx')
674        self.assertEqual(tuple(view), ('iter',))
675        self.assertEqual(len(view), 500)
676        self.assertEqual(view.copy(), 'copy')
677        self.assertEqual(view.get('key'), 'get=value')
678        self.assertEqual(view.get('xxx'), 'get=default=None')
679        self.assertEqual(view.items(), 'items')
680        self.assertEqual(view.keys(), 'keys')
681        self.assertEqual(view.values(), 'values')
682
683    def test_chainmap(self):
684        d1 = {'x': 1}
685        d2 = {'y': 2}
686        mapping = collections.ChainMap(d1, d2)
687        view = self.mappingproxy(mapping)
688        self.assertTrue('x' in view)
689        self.assertTrue('y' in view)
690        self.assertFalse('z' in view)
691        self.assertEqual(view['x'], 1)
692        self.assertEqual(view['y'], 2)
693        self.assertRaises(KeyError, view.__getitem__, 'z')
694        self.assertEqual(tuple(sorted(view)), ('x', 'y'))
695        self.assertEqual(len(view), 2)
696        copy = view.copy()
697        self.assertIsNot(copy, mapping)
698        self.assertIsInstance(copy, collections.ChainMap)
699        self.assertEqual(copy, mapping)
700        self.assertEqual(view.get('x'), 1)
701        self.assertEqual(view.get('y'), 2)
702        self.assertIsNone(view.get('z'))
703        self.assertEqual(tuple(sorted(view.items())), (('x', 1), ('y', 2)))
704        self.assertEqual(tuple(sorted(view.keys())), ('x', 'y'))
705        self.assertEqual(tuple(sorted(view.values())), (1, 2))
706
707    def test_contains(self):
708        view = self.mappingproxy(dict.fromkeys('abc'))
709        self.assertTrue('a' in view)
710        self.assertTrue('b' in view)
711        self.assertTrue('c' in view)
712        self.assertFalse('xxx' in view)
713
714    def test_views(self):
715        mapping = {}
716        view = self.mappingproxy(mapping)
717        keys = view.keys()
718        values = view.values()
719        items = view.items()
720        self.assertEqual(list(keys), [])
721        self.assertEqual(list(values), [])
722        self.assertEqual(list(items), [])
723        mapping['key'] = 'value'
724        self.assertEqual(list(keys), ['key'])
725        self.assertEqual(list(values), ['value'])
726        self.assertEqual(list(items), [('key', 'value')])
727
728    def test_len(self):
729        for expected in range(6):
730            data = dict.fromkeys('abcde'[:expected])
731            self.assertEqual(len(data), expected)
732            view = self.mappingproxy(data)
733            self.assertEqual(len(view), expected)
734
735    def test_iterators(self):
736        keys = ('x', 'y')
737        values = (1, 2)
738        items = tuple(zip(keys, values))
739        view = self.mappingproxy(dict(items))
740        self.assertEqual(set(view), set(keys))
741        self.assertEqual(set(view.keys()), set(keys))
742        self.assertEqual(set(view.values()), set(values))
743        self.assertEqual(set(view.items()), set(items))
744
745    def test_copy(self):
746        original = {'key1': 27, 'key2': 51, 'key3': 93}
747        view = self.mappingproxy(original)
748        copy = view.copy()
749        self.assertEqual(type(copy), dict)
750        self.assertEqual(copy, original)
751        original['key1'] = 70
752        self.assertEqual(view['key1'], 70)
753        self.assertEqual(copy['key1'], 27)
754
755
class ClassCreationTests(unittest.TestCase):
    """Tests for dynamic class creation with types.new_class() and
    types.prepare_class(), including metaclass calculation."""

    class Meta(type):
        # Helper metaclass shared by several tests.  __prepare__ seeds the
        # class namespace with y=1 and copies every class keyword into it,
        # so a test can observe that keywords reached the metaclass.
        # __new__ and __init__ accept **kw but deliberately do not forward
        # it to type.
        def __init__(cls, name, bases, ns, **kw):
            super().__init__(name, bases, ns)
        @staticmethod
        def __new__(mcls, name, bases, ns, **kw):
            return super().__new__(mcls, name, bases, ns)
        @classmethod
        def __prepare__(mcls, name, bases, **kw):
            ns = super().__prepare__(name, bases)
            ns["y"] = 1
            ns.update(kw)
            return ns

    def test_new_class_basics(self):
        # With only a name, new_class() builds a direct subclass of object.
        C = types.new_class("C")
        self.assertEqual(C.__name__, "C")
        self.assertEqual(C.__bases__, (object,))

    def test_new_class_subclass(self):
        C = types.new_class("C", (int,))
        self.assertTrue(issubclass(C, int))

    def test_new_class_meta(self):
        # The "metaclass" keyword selects the metaclass; the remaining
        # keyword ("z") is handed to it, and Meta.__prepare__ stores it in
        # the class namespace.
        Meta = self.Meta
        settings = {"metaclass": Meta, "z": 2}
        # We do this twice to make sure the passed in dict isn't mutated
        for i in range(2):
            C = types.new_class("C" + str(i), (), settings)
            self.assertIsInstance(C, Meta)
            self.assertEqual(C.y, 1)
            self.assertEqual(C.z, 2)
            # Also check the caller's dict directly, so a mutation that
            # does not happen to change the next class is still caught.
            self.assertEqual(settings, {"metaclass": Meta, "z": 2})

    def test_new_class_exec_body(self):
        # exec_body is called with the prepared namespace and may add to it.
        Meta = self.Meta
        def func(ns):
            ns["x"] = 0
        C = types.new_class("C", (), {"metaclass": Meta, "z": 2}, func)
        self.assertIsInstance(C, Meta)
        self.assertEqual(C.x, 0)
        self.assertEqual(C.y, 1)
        self.assertEqual(C.z, 2)

    def test_new_class_metaclass_keywords(self):
        #Test that keywords are passed to the metaclass:
        # a plain function used as the metaclass receives the name, the
        # bases, an empty namespace and the remaining keywords.
        def meta_func(name, bases, ns, **kw):
            return name, bases, ns, kw
        res = types.new_class("X",
                              (int, object),
                              dict(metaclass=meta_func, x=0))
        self.assertEqual(res, ("X", (int, object), {}, {"x": 0}))

    def test_new_class_defaults(self):
        # Test defaults/keywords: an explicit empty kwds dict and
        # exec_body=None behave like omitting them.
        C = types.new_class("C", (), {}, None)
        self.assertEqual(C.__name__, "C")
        self.assertEqual(C.__bases__, (object,))

    def test_new_class_meta_with_base(self):
        # Every parameter passed by keyword, combining a base class, a
        # metaclass with an extra keyword, and an exec_body callback.
        Meta = self.Meta
        def func(ns):
            ns["x"] = 0
        C = types.new_class(name="C",
                            bases=(int,),
                            kwds=dict(metaclass=Meta, z=2),
                            exec_body=func)
        self.assertTrue(issubclass(C, int))
        self.assertIsInstance(C, Meta)
        self.assertEqual(C.x, 0)
        self.assertEqual(C.y, 1)
        self.assertEqual(C.z, 2)

    # Many of the following tests are derived from test_descr.py
    def test_prepare_class(self):
        # Basic test of metaclass derivation
        expected_ns = {}
        class A(type):
            def __new__(*args, **kwargs):
                return type.__new__(*args, **kwargs)

            def __prepare__(*args):
                return expected_ns

        B = types.new_class("B", (object,))
        C = types.new_class("C", (object,), {"metaclass": A})

        # The most derived metaclass of D is A rather than type.
        meta, ns, kwds = types.prepare_class("D", (B, C), {"metaclass": type})
        self.assertIs(meta, A)
        self.assertIs(ns, expected_ns)
        # The "metaclass" entry must have been consumed; comparing against
        # {} shows any leftover keywords if this fails.
        self.assertEqual(kwds, {})

    def test_metaclass_derivation(self):
        # issue1294232: correct metaclass calculation
        new_calls = []  # to check the order of __new__ calls
        class AMeta(type):
            def __new__(mcls, name, bases, ns):
                new_calls.append('AMeta')
                return super().__new__(mcls, name, bases, ns)
            @classmethod
            def __prepare__(mcls, name, bases):
                return {}

        class BMeta(AMeta):
            def __new__(mcls, name, bases, ns):
                new_calls.append('BMeta')
                return super().__new__(mcls, name, bases, ns)
            @classmethod
            def __prepare__(mcls, name, bases):
                ns = super().__prepare__(name, bases)
                ns['BMeta_was_here'] = True
                return ns

        A = types.new_class("A", (), {"metaclass": AMeta})
        self.assertEqual(new_calls, ['AMeta'])
        new_calls.clear()

        B = types.new_class("B", (), {"metaclass": BMeta})
        # BMeta.__new__ calls AMeta.__new__ with super:
        self.assertEqual(new_calls, ['BMeta', 'AMeta'])
        new_calls.clear()

        C = types.new_class("C", (A, B))
        # The most derived metaclass is BMeta:
        self.assertEqual(new_calls, ['BMeta', 'AMeta'])
        new_calls.clear()
        # BMeta.__prepare__ should've been called:
        self.assertIn('BMeta_was_here', C.__dict__)

        # The order of the bases shouldn't matter:
        C2 = types.new_class("C2", (B, A))
        self.assertEqual(new_calls, ['BMeta', 'AMeta'])
        new_calls.clear()
        self.assertIn('BMeta_was_here', C2.__dict__)

        # Check correct metaclass calculation when a metaclass is declared:
        D = types.new_class("D", (C,), {"metaclass": type})
        self.assertEqual(new_calls, ['BMeta', 'AMeta'])
        new_calls.clear()
        self.assertIn('BMeta_was_here', D.__dict__)

        E = types.new_class("E", (C,), {"metaclass": AMeta})
        self.assertEqual(new_calls, ['BMeta', 'AMeta'])
        new_calls.clear()
        self.assertIn('BMeta_was_here', E.__dict__)

    def test_metaclass_override_function(self):
        # Special case: the given metaclass isn't a class,
        # so there is no metaclass calculation.
        class A(metaclass=self.Meta):
            pass

        marker = object()
        def func(*args, **kwargs):
            return marker

        X = types.new_class("X", (), {"metaclass": func})
        Y = types.new_class("Y", (object,), {"metaclass": func})
        Z = types.new_class("Z", (A,), {"metaclass": func})
        self.assertIs(marker, X)
        self.assertIs(marker, Y)
        self.assertIs(marker, Z)

    def test_metaclass_override_callable(self):
        # The given metaclass is a class,
        # but not a descendant of type.
        new_calls = []  # to check the order of __new__ calls
        prepare_calls = []  # to track __prepare__ calls
        class ANotMeta:
            def __new__(mcls, *args, **kwargs):
                new_calls.append('ANotMeta')
                return super().__new__(mcls)
            @classmethod
            def __prepare__(mcls, name, bases):
                prepare_calls.append('ANotMeta')
                return {}

        class BNotMeta(ANotMeta):
            def __new__(mcls, *args, **kwargs):
                new_calls.append('BNotMeta')
                return super().__new__(mcls)
            @classmethod
            def __prepare__(mcls, name, bases):
                prepare_calls.append('BNotMeta')
                return super().__prepare__(name, bases)

        A = types.new_class("A", (), {"metaclass": ANotMeta})
        self.assertIs(ANotMeta, type(A))
        self.assertEqual(prepare_calls, ['ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['ANotMeta'])
        new_calls.clear()

        B = types.new_class("B", (), {"metaclass": BNotMeta})
        self.assertIs(BNotMeta, type(B))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        C = types.new_class("C", (A, B))
        self.assertIs(BNotMeta, type(C))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        C2 = types.new_class("C2", (B, A))
        self.assertIs(BNotMeta, type(C2))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        # This is a TypeError, because of a metaclass conflict:
        # BNotMeta is neither a subclass, nor a superclass of type
        with self.assertRaises(TypeError):
            types.new_class("D", (C,), {"metaclass": type})

        E = types.new_class("E", (C,), {"metaclass": ANotMeta})
        self.assertIs(BNotMeta, type(E))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        F = types.new_class("F", (object(), C))
        self.assertIs(BNotMeta, type(F))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        F2 = types.new_class("F2", (C, object()))
        self.assertIs(BNotMeta, type(F2))
        self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
        prepare_calls.clear()
        self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
        new_calls.clear()

        # TypeError: BNotMeta is neither a
        # subclass, nor a superclass of int
        with self.assertRaises(TypeError):
            types.new_class("X", (C, int()))
        with self.assertRaises(TypeError):
            types.new_class("X", (int(), C))

    def test_one_argument_type(self):
        expected_message = 'type.__new__() takes exactly 3 arguments (1 given)'

        # Only type itself can use the one-argument form (#27157)
        self.assertIs(type(5), int)

        class M(type):
            pass
        with self.assertRaises(TypeError) as cm:
            M(5)
        self.assertEqual(str(cm.exception), expected_message)

        class N(type, metaclass=M):
            pass
        with self.assertRaises(TypeError) as cm:
            N(5)
        self.assertEqual(str(cm.exception), expected_message)
1021
1022
class SimpleNamespaceTests(unittest.TestCase):
    """Tests for types.SimpleNamespace: construction, attribute access,
    equality, repr (including recursion), pickling and subclassing."""

    def test_constructor(self):
        # Keyword arguments (literal or **-unpacked) become attributes;
        # positional arguments are rejected.
        ns1 = types.SimpleNamespace()
        ns2 = types.SimpleNamespace(x=1, y=2)
        ns3 = types.SimpleNamespace(**dict(x=1, y=2))

        with self.assertRaises(TypeError):
            types.SimpleNamespace(1, 2, 3)

        self.assertEqual(len(ns1.__dict__), 0)
        self.assertEqual(vars(ns1), {})
        self.assertEqual(len(ns2.__dict__), 2)
        self.assertEqual(vars(ns2), {'y': 2, 'x': 1})
        self.assertEqual(len(ns3.__dict__), 2)
        self.assertEqual(vars(ns3), {'y': 2, 'x': 1})

    def test_unbound(self):
        # vars() works on namespaces that are not bound to any name.
        ns1 = vars(types.SimpleNamespace())
        ns2 = vars(types.SimpleNamespace(x=1, y=2))

        self.assertEqual(ns1, {})
        self.assertEqual(ns2, {'y': 2, 'x': 1})

    def test_underlying_dict(self):
        # __dict__ holds the attributes and stays usable after the
        # namespace itself is deleted.
        ns1 = types.SimpleNamespace()
        ns2 = types.SimpleNamespace(x=1, y=2)
        ns3 = types.SimpleNamespace(a=True, b=False)
        mapping = ns3.__dict__
        del ns3

        self.assertEqual(ns1.__dict__, {})
        self.assertEqual(ns2.__dict__, {'y': 2, 'x': 1})
        self.assertEqual(mapping, dict(a=True, b=False))

    def test_attrget(self):
        ns = types.SimpleNamespace(x=1, y=2, w=3)

        self.assertEqual(ns.x, 1)
        self.assertEqual(ns.y, 2)
        self.assertEqual(ns.w, 3)
        with self.assertRaises(AttributeError):
            ns.z

    def test_attrset(self):
        ns1 = types.SimpleNamespace()
        ns2 = types.SimpleNamespace(x=1, y=2, w=3)
        ns1.a = 'spam'
        ns1.b = 'ham'
        ns2.z = 4
        ns2.theta = None

        self.assertEqual(ns1.__dict__, dict(a='spam', b='ham'))
        self.assertEqual(ns2.__dict__, dict(x=1, y=2, w=3, z=4, theta=None))

    def test_attrdel(self):
        # Deleting a missing attribute raises AttributeError; deleted
        # attributes can be re-added and deleted again.
        ns1 = types.SimpleNamespace()
        ns2 = types.SimpleNamespace(x=1, y=2, w=3)

        with self.assertRaises(AttributeError):
            del ns1.spam
        with self.assertRaises(AttributeError):
            del ns2.spam

        del ns2.y
        self.assertEqual(vars(ns2), dict(w=3, x=1))
        ns2.y = 'spam'
        self.assertEqual(vars(ns2), dict(w=3, x=1, y='spam'))
        del ns2.y
        self.assertEqual(vars(ns2), dict(w=3, x=1))

        ns1.spam = 5
        self.assertEqual(vars(ns1), dict(spam=5))
        del ns1.spam
        self.assertEqual(vars(ns1), {})

    def test_repr(self):
        ns1 = types.SimpleNamespace(x=1, y=2, w=3)
        ns2 = types.SimpleNamespace()
        ns2.x = "spam"
        ns2._y = 5
        name = "namespace"

        if sys.version_info >= (3, 9):
            # Since 3.9 the repr lists attributes in insertion order.
            repr1 = "{name}(x=1, y=2, w=3)"
            repr2 = "{name}(x='spam', _y=5)"
        else:
            # Before 3.9 the repr listed attributes alphabetically.
            repr1 = "{name}(w=3, x=1, y=2)"
            repr2 = "{name}(_y=5, x='spam')"
        self.assertEqual(repr(ns1), repr1.format(name=name))
        self.assertEqual(repr(ns2), repr2.format(name=name))

    def test_equal(self):
        # Namespaces compare equal when their attributes are equal,
        # regardless of how the attributes were set.
        ns1 = types.SimpleNamespace(x=1)
        ns2 = types.SimpleNamespace()
        ns2.x = 1

        self.assertEqual(types.SimpleNamespace(), types.SimpleNamespace())
        self.assertEqual(ns1, ns2)
        self.assertNotEqual(ns2, types.SimpleNamespace())

    def test_nested(self):
        ns1 = types.SimpleNamespace(a=1, b=2)
        ns2 = types.SimpleNamespace()
        ns3 = types.SimpleNamespace(x=ns1)
        # spam is bound twice; only the second binding (ns3) must remain.
        ns2.spam = ns1
        ns2.ham = '?'
        ns2.spam = ns3

        self.assertEqual(vars(ns1), dict(a=1, b=2))
        self.assertEqual(vars(ns2), dict(spam=ns3, ham='?'))
        self.assertEqual(ns2.spam, ns3)
        self.assertEqual(vars(ns3), dict(x=ns1))
        self.assertEqual(ns3.x.a, 1)

    def test_recursive(self):
        # Self-referencing and mutually-referencing namespaces.
        ns1 = types.SimpleNamespace(c='cookie')
        ns2 = types.SimpleNamespace()
        ns3 = types.SimpleNamespace(x=1)
        ns1.spam = ns1
        ns2.spam = ns3
        ns3.spam = ns2

        self.assertEqual(ns1.spam, ns1)
        self.assertEqual(ns1.spam.spam, ns1)
        self.assertEqual(ns1.spam.spam, ns1.spam)
        self.assertEqual(ns2.spam, ns3)
        self.assertEqual(ns3.spam, ns2)
        self.assertEqual(ns2.spam.spam, ns2)

    def test_recursive_repr(self):
        # A namespace already being repr'ed shows up as "namespace(...)".
        ns1 = types.SimpleNamespace(c='cookie')
        ns2 = types.SimpleNamespace()
        ns3 = types.SimpleNamespace(x=1)
        ns1.spam = ns1
        ns2.spam = ns3
        ns3.spam = ns2
        name = "namespace"
        repr1 = "{name}(c='cookie', spam={name}(...))".format(name=name)
        if sys.version_info >= (3, 9):
            # Insertion order: ns3 got x before spam.
            repr2 = "{name}(spam={name}(x=1, spam={name}(...)))".format(name=name)
        else:
            # Alphabetical order: spam sorts before x.
            repr2 = "{name}(spam={name}(spam={name}(...), x=1))".format(name=name)

        self.assertEqual(repr(ns1), repr1)
        self.assertEqual(repr(ns2), repr2)

    def test_as_dict(self):
        # A namespace is not a mapping or container.
        ns = types.SimpleNamespace(spam='spamspamspam')

        with self.assertRaises(TypeError):
            len(ns)
        with self.assertRaises(TypeError):
            iter(ns)
        with self.assertRaises(TypeError):
            'spam' in ns
        with self.assertRaises(TypeError):
            ns['spam']

    def test_subclass(self):
        class Spam(types.SimpleNamespace):
            pass

        spam = Spam(ham=8, eggs=9)

        self.assertIs(type(spam), Spam)
        self.assertEqual(vars(spam), {'ham': 8, 'eggs': 9})

    def test_pickle(self):
        # Round-trip through every pickle protocol.  subTest labels a
        # failure with its protocol and keeps checking the others.
        ns = types.SimpleNamespace(breakfast="spam", lunch="spam")

        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(protocol=protocol):
                ns_pickled = pickle.dumps(ns, protocol)
                ns_roundtrip = pickle.loads(ns_pickled)
                self.assertEqual(ns, ns_roundtrip)

    def test_fake_namespace_compare(self):
        # Issue #24257: Incorrect use of PyObject_IsInstance() caused
        # SystemError.
        class FakeSimpleNamespace(str):
            __class__ = types.SimpleNamespace
        self.assertFalse(types.SimpleNamespace() == FakeSimpleNamespace())
        self.assertTrue(types.SimpleNamespace() != FakeSimpleNamespace())
        with self.assertRaises(TypeError):
            types.SimpleNamespace() < FakeSimpleNamespace()
        with self.assertRaises(TypeError):
            types.SimpleNamespace() <= FakeSimpleNamespace()
        with self.assertRaises(TypeError):
            types.SimpleNamespace() > FakeSimpleNamespace()
        with self.assertRaises(TypeError):
            types.SimpleNamespace() >= FakeSimpleNamespace()
1210
1211
class CoroutineTests(unittest.TestCase):
    """Tests for the types.coroutine() decorator and the
    types._GeneratorWrapper it uses for generator-like results."""

    def test_wrong_args(self):
        samples = [None, 1, object()]
        for sample in samples:
            with self.assertRaisesRegex(TypeError,
                                        'types.coroutine.*expects a callable'):
                types.coroutine(sample)

    def test_non_gen_values(self):
        # Return values that are not generators (a str, an awaitable) are
        # passed through unchanged, even after decorating twice.
        @types.coroutine
        def foo():
            return 'spam'
        self.assertEqual(foo(), 'spam')

        class Awaitable:
            def __await__(self):
                return ()
        aw = Awaitable()
        @types.coroutine
        def foo():
            return aw
        self.assertIs(aw, foo())

        # decorate foo second time
        foo = types.coroutine(foo)
        self.assertIs(aw, foo())

    def test_async_def(self):
        # Test that types.coroutine passes 'async def' coroutines
        # without modification

        async def foo(): pass
        foo_code = foo.__code__
        foo_flags = foo.__code__.co_flags
        decorated_foo = types.coroutine(foo)
        self.assertIs(foo, decorated_foo)
        self.assertEqual(foo.__code__.co_flags, foo_flags)
        self.assertIs(decorated_foo.__code__, foo_code)

        # A plain function returning a native coroutine object gets that
        # very object back, however often it is decorated.
        foo_coro = foo()
        def bar(): return foo_coro
        for _ in range(2):
            bar = types.coroutine(bar)
            coro = bar()
            self.assertIs(foo_coro, coro)
            self.assertEqual(coro.cr_code.co_flags, foo_flags)
            coro.close()

    def test_duck_coro(self):
        # An object with send/throw/close/__await__ is returned unwrapped.
        class CoroLike:
            def send(self): pass
            def throw(self): pass
            def close(self): pass
            def __await__(self): return self

        coro = CoroLike()
        @types.coroutine
        def foo():
            return coro
        self.assertIs(foo(), coro)
        self.assertIs(foo().__await__(), coro)

    def test_duck_corogen(self):
        # Same as above, but the object is also an iterator.
        class CoroGenLike:
            def send(self): pass
            def throw(self): pass
            def close(self): pass
            def __await__(self): return self
            def __iter__(self): return self
            def __next__(self): pass

        coro = CoroGenLike()
        @types.coroutine
        def foo():
            return coro
        self.assertIs(foo(), coro)
        self.assertIs(foo().__await__(), coro)

    def test_duck_gen(self):
        # A generator-like object without __await__ is wrapped in a
        # types._GeneratorWrapper that forwards to it.
        class GenLike:
            def send(self): pass
            def throw(self): pass
            def close(self): pass
            def __iter__(self): pass
            def __next__(self): pass

        # Setup generator mock object
        gen = unittest.mock.MagicMock(GenLike)
        gen.__iter__ = lambda gen: gen
        gen.__name__ = 'gen'
        gen.__qualname__ = 'test.gen'
        self.assertIsInstance(gen, collections.abc.Generator)
        self.assertIs(gen, iter(gen))

        @types.coroutine
        def foo(): return gen

        wrapper = foo()
        self.assertIsInstance(wrapper, types._GeneratorWrapper)
        self.assertIs(wrapper.__await__(), wrapper)
        # Wrapper proxies duck generators completely:
        self.assertIs(iter(wrapper), wrapper)

        self.assertIsInstance(wrapper, collections.abc.Coroutine)
        self.assertIsInstance(wrapper, collections.abc.Awaitable)

        self.assertIs(wrapper.__qualname__, gen.__qualname__)
        self.assertIs(wrapper.__name__, gen.__name__)

        # Test AttributeErrors: the spec'd mock does not define these yet,
        # so reading them through the wrapper must fail.
        for name in {'gi_running', 'gi_frame', 'gi_code', 'gi_yieldfrom',
                     'cr_running', 'cr_frame', 'cr_code', 'cr_await'}:
            with self.assertRaises(AttributeError):
                getattr(wrapper, name)

        # Test attributes pass-through; the cr_* names map to gi_* ones.
        gen.gi_running = object()
        gen.gi_frame = object()
        gen.gi_code = object()
        gen.gi_yieldfrom = object()
        self.assertIs(wrapper.gi_running, gen.gi_running)
        self.assertIs(wrapper.gi_frame, gen.gi_frame)
        self.assertIs(wrapper.gi_code, gen.gi_code)
        self.assertIs(wrapper.gi_yieldfrom, gen.gi_yieldfrom)
        self.assertIs(wrapper.cr_running, gen.gi_running)
        self.assertIs(wrapper.cr_frame, gen.gi_frame)
        self.assertIs(wrapper.cr_code, gen.gi_code)
        self.assertIs(wrapper.cr_await, gen.gi_yieldfrom)

        wrapper.close()
        gen.close.assert_called_once_with()

        wrapper.send(1)
        gen.send.assert_called_once_with(1)
        gen.reset_mock()

        next(wrapper)
        gen.__next__.assert_called_once_with()
        gen.reset_mock()

        wrapper.throw(1, 2, 3)
        gen.throw.assert_called_once_with(1, 2, 3)
        gen.reset_mock()

        wrapper.throw(1, 2)
        gen.throw.assert_called_once_with(1, 2)
        gen.reset_mock()

        wrapper.throw(1)
        gen.throw.assert_called_once_with(1)
        gen.reset_mock()

        # Test exceptions propagation: the exact exception object raised
        # by the wrapped throw() must reach the caller.
        error = Exception()
        gen.throw.side_effect = error
        with self.assertRaises(Exception) as cm:
            wrapper.throw(1)
        self.assertIs(cm.exception, error)

        # Test invalid args: rejected by the wrapper without reaching gen.
        gen.reset_mock()
        with self.assertRaises(TypeError):
            wrapper.throw()
        self.assertFalse(gen.throw.called)
        with self.assertRaises(TypeError):
            wrapper.close(1)
        self.assertFalse(gen.close.called)
        with self.assertRaises(TypeError):
            wrapper.send()
        self.assertFalse(gen.send.called)

        # Test that we do not double wrap
        @types.coroutine
        def bar(): return wrapper
        self.assertIs(wrapper, bar())

        # Test weakrefs support
        ref = weakref.ref(wrapper)
        self.assertIs(ref(), wrapper)

    def test_duck_functional_gen(self):
        # A hand-written generator-like class, wrapped by types.coroutine,
        # can be awaited from a native coroutine.
        class Generator:
            """Emulates the following generator (very clumsy):

              def gen(fut):
                  result = yield fut
                  return result * 2
            """
            def __init__(self, fut):
                self._i = 0
                self._fut = fut
            def __iter__(self):
                return self
            def __next__(self):
                return self.send(None)
            def send(self, v):
                try:
                    if self._i == 0:
                        assert v is None
                        return self._fut
                    if self._i == 1:
                        raise StopIteration(v * 2)
                    if self._i > 1:
                        raise StopIteration
                finally:
                    self._i += 1
            def throw(self, tp, *exc):
                self._i = 100
                if tp is not GeneratorExit:
                    raise tp
            def close(self):
                self.throw(GeneratorExit)

        @types.coroutine
        def foo(): return Generator('spam')

        wrapper = foo()
        self.assertIsInstance(wrapper, types._GeneratorWrapper)

        async def corofunc():
            return await foo() + 100
        coro = corofunc()

        # First send yields the future; second send resumes with 20, so
        # the coroutine returns 20 * 2 + 100.
        self.assertEqual(coro.send(None), 'spam')
        with self.assertRaises(StopIteration) as cm:
            coro.send(20)
        self.assertEqual(cm.exception.args[0], 140)

    def test_gen(self):
        # A real generator returned by a plain function is wrapped; the
        # wrapper's __await__ returns the generator itself.
        def gen_func():
            yield 1
            return (yield 2)
        gen = gen_func()
        @types.coroutine
        def foo(): return gen
        wrapper = foo()
        self.assertIsInstance(wrapper, types._GeneratorWrapper)
        self.assertIs(wrapper.__await__(), gen)

        for name in ('__name__', '__qualname__', 'gi_code',
                     'gi_running', 'gi_frame'):
            self.assertIs(getattr(foo(), name),
                          getattr(gen, name))
        self.assertIs(foo().cr_code, gen.gi_code)

        self.assertEqual(next(wrapper), 1)
        self.assertEqual(wrapper.send(None), 2)
        with self.assertRaisesRegex(StopIteration, 'spam'):
            wrapper.send('spam')

        gen = gen_func()
        wrapper = foo()
        wrapper.send(None)
        with self.assertRaisesRegex(Exception, 'ham'):
            wrapper.throw(Exception, Exception('ham'))

        # decorate foo second time
        foo = types.coroutine(foo)
        self.assertIs(foo().__await__(), gen)

    def test_returning_itercoro(self):
        # A generator-based coroutine (from a decorated generator function)
        # is returned as-is rather than wrapped.
        @types.coroutine
        def gen():
            yield

        gencoro = gen()

        @types.coroutine
        def foo():
            return gencoro

        self.assertIs(foo(), gencoro)

        # decorate foo second time
        foo = types.coroutine(foo)
        self.assertIs(foo(), gencoro)

    def test_genfunc(self):
        # Decorating a generator function returns the same function with
        # CO_ITERABLE_COROUTINE set (and CO_COROUTINE not set).
        def gen(): yield
        self.assertIs(types.coroutine(gen), gen)
        self.assertIs(types.coroutine(types.coroutine(gen)), gen)

        self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
        self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)

        g = gen()
        self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
        self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)

        self.assertIs(types.coroutine(gen), gen)

    def test_wrapper_object(self):
        # repr/str and dir() of the wrapper expose the proxied protocol.
        def gen():
            yield
        @types.coroutine
        def coro():
            return gen()

        wrapper = coro()
        self.assertIn('GeneratorWrapper', repr(wrapper))
        self.assertEqual(repr(wrapper), str(wrapper))
        self.assertTrue(set(dir(wrapper)).issuperset({
            '__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
            'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
            'close', 'throw'}))
1523
1524
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
1527