1import contextlib
2import copy
3import inspect
4import pickle
5import sys
6import types
7import unittest
8import warnings
9from test import support
10
11
class AsyncYieldFrom:
    """Awaitable that delegates its ``__await__`` to an arbitrary iterable.

    Awaiting an instance yields every item of *obj*, in order, to whoever is
    driving the outermost coroutine (delegation uses ``yield from``).
    """

    def __init__(self, obj):
        # Stored as-is; it is only iterated once the instance is awaited.
        self.obj = obj

    def __await__(self):
        source = self.obj
        yield from source
18
19
class AsyncYield:
    """Awaitable that yields exactly one value, *value*, when awaited."""

    def __init__(self, value):
        # The single item handed to the coroutine driver on await.
        self.value = value

    def __await__(self):
        single = self.value
        yield single
26
27
def run_async(coro):
    """Drive *coro* to completion by repeatedly calling ``send(None)``.

    *coro* must be a generator or a native coroutine object.

    Returns a ``(yielded, result)`` pair. *yielded* is the list of values
    the coroutine yielded along the way. *result* is the value it returned,
    or ``None`` if it returned nothing. Any exception other than
    StopIteration propagates to the caller.
    """
    assert coro.__class__ in {types.GeneratorType, types.CoroutineType}

    yielded = []
    while True:
        try:
            item = coro.send(None)
        except StopIteration as stop:
            # StopIteration.value is args[0], or None when args is empty.
            return yielded, stop.value
        yielded.append(item)
40
41
def run_async__await__(coro):
    """Drive a native coroutine through its ``__await__()`` wrapper.

    Resumption alternates between ``wrapper.send(None)`` (first step) and
    ``next(wrapper)``, so that both entry points of the wrapper are
    exercised.

    Returns ``(yielded, result)`` in the same shape as run_async().
    """
    assert coro.__class__ is types.CoroutineType
    wrapper = coro.__await__()
    yielded = []
    use_next = False
    while True:
        try:
            item = next(wrapper) if use_next else wrapper.send(None)
        except StopIteration as stop:
            return yielded, stop.value
        yielded.append(item)
        use_next = not use_next
59
60
@contextlib.contextmanager
def silence_coro_gc():
    """Ignore warnings inside the block, then garbage-collect before leaving.

    ``support.gc_collect()`` runs while warnings are still ignored. Objects
    the body dropped, such as coroutines that were never awaited, are
    therefore finalized without their warnings escaping into later tests.
    Collection runs in a ``finally`` clause, so it also happens when the
    body raises (for example, a failing assertion).
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            yield
        finally:
            # Collect while the "ignore" filter is still in effect.
            support.gc_collect()
67
68
69class AsyncBadSyntaxTest(unittest.TestCase):
70
71    def test_badsyntax_1(self):
72        samples = [
73            """def foo():
74                await something()
75            """,
76
77            """await something()""",
78
79            """async def foo():
80                yield from []
81            """,
82
83            """async def foo():
84                await await fut
85            """,
86
87            """async def foo(a=await something()):
88                pass
89            """,
90
91            """async def foo(a:await something()):
92                pass
93            """,
94
95            """async def foo():
96                def bar():
97                 [i async for i in els]
98            """,
99
100            """async def foo():
101                def bar():
102                 [await i for i in els]
103            """,
104
105            """async def foo():
106                def bar():
107                 [i for i in els
108                    async for b in els]
109            """,
110
111            """async def foo():
112                def bar():
113                 [i for i in els
114                    for c in b
115                    async for b in els]
116            """,
117
118            """async def foo():
119                def bar():
120                 [i for i in els
121                    async for b in els
122                    for c in b]
123            """,
124
125            """async def foo():
126                def bar():
127                 [i for i in els
128                    for b in await els]
129            """,
130
131            """async def foo():
132                def bar():
133                 [i for i in els
134                    for b in els
135                        if await b]
136            """,
137
138            """async def foo():
139                def bar():
140                 [i for i in await els]
141            """,
142
143            """async def foo():
144                def bar():
145                 [i for i in els if await i]
146            """,
147
148            """def bar():
149                 [i async for i in els]
150            """,
151
152            """def bar():
153                 [await i for i in els]
154            """,
155
156            """def bar():
157                 [i for i in els
158                    async for b in els]
159            """,
160
161            """def bar():
162                 [i for i in els
163                    for c in b
164                    async for b in els]
165            """,
166
167            """def bar():
168                 [i for i in els
169                    async for b in els
170                    for c in b]
171            """,
172
173            """def bar():
174                 [i for i in els
175                    for b in await els]
176            """,
177
178            """def bar():
179                 [i for i in els
180                    for b in els
181                        if await b]
182            """,
183
184            """def bar():
185                 [i for i in await els]
186            """,
187
188            """def bar():
189                 [i for i in els if await i]
190            """,
191
192            """async def foo():
193                await
194            """,
195
196            """async def foo():
197                   def bar(): pass
198                   await = 1
199            """,
200
201            """async def foo():
202
203                   def bar(): pass
204                   await = 1
205            """,
206
207            """async def foo():
208                   def bar(): pass
209                   if 1:
210                       await = 1
211            """,
212
213            """def foo():
214                   async def bar(): pass
215                   if 1:
216                       await a
217            """,
218
219            """def foo():
220                   async def bar(): pass
221                   await a
222            """,
223
224            """def foo():
225                   def baz(): pass
226                   async def bar(): pass
227                   await a
228            """,
229
230            """def foo():
231                   def baz(): pass
232                   # 456
233                   async def bar(): pass
234                   # 123
235                   await a
236            """,
237
238            """async def foo():
239                   def baz(): pass
240                   # 456
241                   async def bar(): pass
242                   # 123
243                   await = 2
244            """,
245
246            """def foo():
247
248                   def baz(): pass
249
250                   async def bar(): pass
251
252                   await a
253            """,
254
255            """async def foo():
256
257                   def baz(): pass
258
259                   async def bar(): pass
260
261                   await = 2
262            """,
263
264            """async def foo():
265                   def async(): pass
266            """,
267
268            """async def foo():
269                   def await(): pass
270            """,
271
272            """async def foo():
273                   def bar():
274                       await
275            """,
276
277            """async def foo():
278                   return lambda async: await
279            """,
280
281            """async def foo():
282                   return lambda a: await
283            """,
284
285            """await a()""",
286
287            """async def foo(a=await b):
288                   pass
289            """,
290
291            """async def foo(a:await b):
292                   pass
293            """,
294
295            """def baz():
296                   async def foo(a=await b):
297                       pass
298            """,
299
300            """async def foo(async):
301                   pass
302            """,
303
304            """async def foo():
305                   def bar():
306                        def baz():
307                            async = 1
308            """,
309
310            """async def foo():
311                   def bar():
312                        def baz():
313                            pass
314                        async = 1
315            """,
316
317            """def foo():
318                   async def bar():
319
320                        async def baz():
321                            pass
322
323                        def baz():
324                            42
325
326                        async = 1
327            """,
328
329            """async def foo():
330                   def bar():
331                        def baz():
332                            pass\nawait foo()
333            """,
334
335            """def foo():
336                   def bar():
337                        async def baz():
338                            pass\nawait foo()
339            """,
340
341            """async def foo(await):
342                   pass
343            """,
344
345            """def foo():
346
347                   async def bar(): pass
348
349                   await a
350            """,
351
352            """def foo():
353                   async def bar():
354                        pass\nawait a
355            """]
356
357        for code in samples:
358            with self.subTest(code=code), self.assertRaises(SyntaxError):
359                compile(code, "<test>", "exec")
360
361    def test_badsyntax_2(self):
362        samples = [
363            """def foo():
364                await = 1
365            """,
366
367            """class Bar:
368                def async(): pass
369            """,
370
371            """class Bar:
372                async = 1
373            """,
374
375            """class async:
376                pass
377            """,
378
379            """class await:
380                pass
381            """,
382
383            """import math as await""",
384
385            """def async():
386                pass""",
387
388            """def foo(*, await=1):
389                pass"""
390
391            """async = 1""",
392
393            """print(await=1)"""
394        ]
395
396        for code in samples:
397            with self.subTest(code=code), self.assertWarnsRegex(
398                    DeprecationWarning,
399                    "'await' will become reserved keywords"):
400                compile(code, "<test>", "exec")
401
402    def test_badsyntax_3(self):
403        with self.assertRaises(DeprecationWarning):
404            with warnings.catch_warnings():
405                warnings.simplefilter("error")
406                compile("async = 1", "<test>", "exec")
407
408    def test_goodsyntax_1(self):
409        # Tests for issue 24619
410
411        samples = [
412            '''def foo(await):
413                async def foo(): pass
414                async def foo():
415                    pass
416                return await + 1
417            ''',
418
419            '''def foo(await):
420                async def foo(): pass
421                async def foo(): pass
422                return await + 1
423            ''',
424
425            '''def foo(await):
426
427                async def foo(): pass
428
429                async def foo(): pass
430
431                return await + 1
432            ''',
433
434            '''def foo(await):
435                """spam"""
436                async def foo(): \
437                    pass
438                # 123
439                async def foo(): pass
440                # 456
441                return await + 1
442            ''',
443
444            '''def foo(await):
445                def foo(): pass
446                def foo(): pass
447                async def bar(): return await_
448                await_ = await
449                try:
450                    bar().send(None)
451                except StopIteration as ex:
452                    return ex.args[0] + 1
453            '''
454        ]
455
456        for code in samples:
457            with self.subTest(code=code):
458                loc = {}
459
460                with warnings.catch_warnings():
461                    warnings.simplefilter("ignore")
462                    exec(code, loc, loc)
463
464                self.assertEqual(loc['foo'](10), 11)
465
466
class TokenizerRegrTest(unittest.TestCase):
    """Regression tests for tokenizing 'async def' after many definitions."""

    def test_oneline_defs(self):
        # Source text with 500 one-line functions: i0() ... i499().
        source = '\n'.join('def i{i}(): return {i}'.format(i=i)
                           for i in range(500))

        # 500 consecutive one-line defs must execute correctly.
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # The same 500 defs followed by a single 'async def' must also work.
        source += '\nasync def foo():\n    return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
487
488
489class CoroutineTest(unittest.TestCase):
490
491    def test_gen_1(self):
492        def gen(): yield
493        self.assertFalse(hasattr(gen, '__await__'))
494
495    def test_func_1(self):
496        async def foo():
497            return 10
498
499        f = foo()
500        self.assertIsInstance(f, types.CoroutineType)
501        self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
502        self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
503        self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
504        self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
505        self.assertEqual(run_async(f), ([], 10))
506
507        self.assertEqual(run_async__await__(foo()), ([], 10))
508
509        def bar(): pass
510        self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))
511
512    def test_func_2(self):
513        async def foo():
514            raise StopIteration
515
516        with self.assertRaisesRegex(
517                RuntimeError, "coroutine raised StopIteration"):
518
519            run_async(foo())
520
521    def test_func_3(self):
522        async def foo():
523            raise StopIteration
524
525        with silence_coro_gc():
526            self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
527
528    def test_func_4(self):
529        async def foo():
530            raise StopIteration
531
532        check = lambda: self.assertRaisesRegex(
533            TypeError, "'coroutine' object is not iterable")
534
535        with check():
536            list(foo())
537
538        with check():
539            tuple(foo())
540
541        with check():
542            sum(foo())
543
544        with check():
545            iter(foo())
546
547        with silence_coro_gc(), check():
548            for i in foo():
549                pass
550
551        with silence_coro_gc(), check():
552            [i for i in foo()]
553
554    def test_func_5(self):
555        @types.coroutine
556        def bar():
557            yield 1
558
559        async def foo():
560            await bar()
561
562        check = lambda: self.assertRaisesRegex(
563            TypeError, "'coroutine' object is not iterable")
564
565        with check():
566            for el in foo(): pass
567
568        # the following should pass without an error
569        for el in bar():
570            self.assertEqual(el, 1)
571        self.assertEqual([el for el in bar()], [1])
572        self.assertEqual(tuple(bar()), (1,))
573        self.assertEqual(next(iter(bar())), 1)
574
575    def test_func_6(self):
576        @types.coroutine
577        def bar():
578            yield 1
579            yield 2
580
581        async def foo():
582            await bar()
583
584        f = foo()
585        self.assertEqual(f.send(None), 1)
586        self.assertEqual(f.send(None), 2)
587        with self.assertRaises(StopIteration):
588            f.send(None)
589
590    def test_func_7(self):
591        async def bar():
592            return 10
593
594        def foo():
595            yield from bar()
596
597        with silence_coro_gc(), self.assertRaisesRegex(
598            TypeError,
599            "cannot 'yield from' a coroutine object in a non-coroutine generator"):
600
601            list(foo())
602
603    def test_func_8(self):
604        @types.coroutine
605        def bar():
606            return (yield from foo())
607
608        async def foo():
609            return 'spam'
610
611        self.assertEqual(run_async(bar()), ([], 'spam') )
612
613    def test_func_9(self):
614        async def foo(): pass
615
616        with self.assertWarnsRegex(
617            RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):
618
619            foo()
620            support.gc_collect()
621
622    def test_func_10(self):
623        N = 0
624
625        @types.coroutine
626        def gen():
627            nonlocal N
628            try:
629                a = yield
630                yield (a ** 2)
631            except ZeroDivisionError:
632                N += 100
633                raise
634            finally:
635                N += 1
636
637        async def foo():
638            await gen()
639
640        coro = foo()
641        aw = coro.__await__()
642        self.assertIs(aw, iter(aw))
643        next(aw)
644        self.assertEqual(aw.send(10), 100)
645
646        self.assertEqual(N, 0)
647        aw.close()
648        self.assertEqual(N, 1)
649
650        coro = foo()
651        aw = coro.__await__()
652        next(aw)
653        with self.assertRaises(ZeroDivisionError):
654            aw.throw(ZeroDivisionError, None, None)
655        self.assertEqual(N, 102)
656
657    def test_func_11(self):
658        async def func(): pass
659        coro = func()
660        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
661        # initialized
662        self.assertIn('__await__', dir(coro))
663        self.assertIn('__iter__', dir(coro.__await__()))
664        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
665        coro.close() # avoid RuntimeWarning
666
667    def test_func_12(self):
668        async def g():
669            i = me.send(None)
670            await foo
671        me = g()
672        with self.assertRaisesRegex(ValueError,
673                                    "coroutine already executing"):
674            me.send(None)
675
676    def test_func_13(self):
677        async def g():
678            pass
679        with self.assertRaisesRegex(
680            TypeError,
681            "can't send non-None value to a just-started coroutine"):
682
683            g().send('spam')
684
685    def test_func_14(self):
686        @types.coroutine
687        def gen():
688            yield
689        async def coro():
690            try:
691                await gen()
692            except GeneratorExit:
693                await gen()
694        c = coro()
695        c.send(None)
696        with self.assertRaisesRegex(RuntimeError,
697                                    "coroutine ignored GeneratorExit"):
698            c.close()
699
700    def test_func_15(self):
701        # See http://bugs.python.org/issue25887 for details
702
703        async def spammer():
704            return 'spam'
705        async def reader(coro):
706            return await coro
707
708        spammer_coro = spammer()
709
710        with self.assertRaisesRegex(StopIteration, 'spam'):
711            reader(spammer_coro).send(None)
712
713        with self.assertRaisesRegex(RuntimeError,
714                                    'cannot reuse already awaited coroutine'):
715            reader(spammer_coro).send(None)
716
717    def test_func_16(self):
718        # See http://bugs.python.org/issue25887 for details
719
720        @types.coroutine
721        def nop():
722            yield
723        async def send():
724            await nop()
725            return 'spam'
726        async def read(coro):
727            await nop()
728            return await coro
729
730        spammer = send()
731
732        reader = read(spammer)
733        reader.send(None)
734        reader.send(None)
735        with self.assertRaisesRegex(Exception, 'ham'):
736            reader.throw(Exception('ham'))
737
738        reader = read(spammer)
739        reader.send(None)
740        with self.assertRaisesRegex(RuntimeError,
741                                    'cannot reuse already awaited coroutine'):
742            reader.send(None)
743
744        with self.assertRaisesRegex(RuntimeError,
745                                    'cannot reuse already awaited coroutine'):
746            reader.throw(Exception('wat'))
747
748    def test_func_17(self):
749        # See http://bugs.python.org/issue25887 for details
750
751        async def coroutine():
752            return 'spam'
753
754        coro = coroutine()
755        with self.assertRaisesRegex(StopIteration, 'spam'):
756            coro.send(None)
757
758        with self.assertRaisesRegex(RuntimeError,
759                                    'cannot reuse already awaited coroutine'):
760            coro.send(None)
761
762        with self.assertRaisesRegex(RuntimeError,
763                                    'cannot reuse already awaited coroutine'):
764            coro.throw(Exception('wat'))
765
766        # Closing a coroutine shouldn't raise any exception even if it's
767        # already closed/exhausted (similar to generators)
768        coro.close()
769        coro.close()
770
771    def test_func_18(self):
772        # See http://bugs.python.org/issue25887 for details
773
774        async def coroutine():
775            return 'spam'
776
777        coro = coroutine()
778        await_iter = coro.__await__()
779        it = iter(await_iter)
780
781        with self.assertRaisesRegex(StopIteration, 'spam'):
782            it.send(None)
783
784        with self.assertRaisesRegex(RuntimeError,
785                                    'cannot reuse already awaited coroutine'):
786            it.send(None)
787
788        with self.assertRaisesRegex(RuntimeError,
789                                    'cannot reuse already awaited coroutine'):
790            # Although the iterator protocol requires iterators to
791            # raise another StopIteration here, we don't want to do
792            # that.  In this particular case, the iterator will raise
793            # a RuntimeError, so that 'yield from' and 'await'
794            # expressions will trigger the error, instead of silently
795            # ignoring the call.
796            next(it)
797
798        with self.assertRaisesRegex(RuntimeError,
799                                    'cannot reuse already awaited coroutine'):
800            it.throw(Exception('wat'))
801
802        with self.assertRaisesRegex(RuntimeError,
803                                    'cannot reuse already awaited coroutine'):
804            it.throw(Exception('wat'))
805
806        # Closing a coroutine shouldn't raise any exception even if it's
807        # already closed/exhausted (similar to generators)
808        it.close()
809        it.close()
810
811    def test_func_19(self):
812        CHK = 0
813
814        @types.coroutine
815        def foo():
816            nonlocal CHK
817            yield
818            try:
819                yield
820            except GeneratorExit:
821                CHK += 1
822
823        async def coroutine():
824            await foo()
825
826        coro = coroutine()
827
828        coro.send(None)
829        coro.send(None)
830
831        self.assertEqual(CHK, 0)
832        coro.close()
833        self.assertEqual(CHK, 1)
834
835        for _ in range(3):
836            # Closing a coroutine shouldn't raise any exception even if it's
837            # already closed/exhausted (similar to generators)
838            coro.close()
839            self.assertEqual(CHK, 1)
840
841    def test_coro_wrapper_send_tuple(self):
842        async def foo():
843            return (10,)
844
845        result = run_async__await__(foo())
846        self.assertEqual(result, ([], (10,)))
847
848    def test_coro_wrapper_send_stop_iterator(self):
849        async def foo():
850            return StopIteration(10)
851
852        result = run_async__await__(foo())
853        self.assertIsInstance(result[1], StopIteration)
854        self.assertEqual(result[1].value, 10)
855
856    def test_cr_await(self):
857        @types.coroutine
858        def a():
859            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
860            self.assertIsNone(coro_b.cr_await)
861            yield
862            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
863            self.assertIsNone(coro_b.cr_await)
864
865        async def c():
866            await a()
867
868        async def b():
869            self.assertIsNone(coro_b.cr_await)
870            await c()
871            self.assertIsNone(coro_b.cr_await)
872
873        coro_b = b()
874        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
875        self.assertIsNone(coro_b.cr_await)
876
877        coro_b.send(None)
878        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
879        self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')
880
881        with self.assertRaises(StopIteration):
882            coro_b.send(None)  # complete coroutine
883        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
884        self.assertIsNone(coro_b.cr_await)
885
886    def test_corotype_1(self):
887        ct = types.CoroutineType
888        self.assertIn('into coroutine', ct.send.__doc__)
889        self.assertIn('inside coroutine', ct.close.__doc__)
890        self.assertIn('in coroutine', ct.throw.__doc__)
891        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
892        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
893        self.assertEqual(ct.__name__, 'coroutine')
894
895        async def f(): pass
896        c = f()
897        self.assertIn('coroutine object', repr(c))
898        c.close()
899
900    def test_await_1(self):
901
902        async def foo():
903            await 1
904        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
905            run_async(foo())
906
907    def test_await_2(self):
908        async def foo():
909            await []
910        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
911            run_async(foo())
912
913    def test_await_3(self):
914        async def foo():
915            await AsyncYieldFrom([1, 2, 3])
916
917        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
918        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
919
920    def test_await_4(self):
921        async def bar():
922            return 42
923
924        async def foo():
925            return await bar()
926
927        self.assertEqual(run_async(foo()), ([], 42))
928
929    def test_await_5(self):
930        class Awaitable:
931            def __await__(self):
932                return
933
934        async def foo():
935            return (await Awaitable())
936
937        with self.assertRaisesRegex(
938            TypeError, "__await__.*returned non-iterator of type"):
939
940            run_async(foo())
941
942    def test_await_6(self):
943        class Awaitable:
944            def __await__(self):
945                return iter([52])
946
947        async def foo():
948            return (await Awaitable())
949
950        self.assertEqual(run_async(foo()), ([52], None))
951
952    def test_await_7(self):
953        class Awaitable:
954            def __await__(self):
955                yield 42
956                return 100
957
958        async def foo():
959            return (await Awaitable())
960
961        self.assertEqual(run_async(foo()), ([42], 100))
962
963    def test_await_8(self):
964        class Awaitable:
965            pass
966
967        async def foo(): return await Awaitable()
968
969        with self.assertRaisesRegex(
970            TypeError, "object Awaitable can't be used in 'await' expression"):
971
972            run_async(foo())
973
974    def test_await_9(self):
975        def wrap():
976            return bar
977
978        async def bar():
979            return 42
980
981        async def foo():
982            b = bar()
983
984            db = {'b':  lambda: wrap}
985
986            class DB:
987                b = wrap
988
989            return (await bar() + await wrap()() + await db['b']()()() +
990                    await bar() * 1000 + await DB.b()())
991
992        async def foo2():
993            return -await bar()
994
995        self.assertEqual(run_async(foo()), ([], 42168))
996        self.assertEqual(run_async(foo2()), ([], -42))
997
998    def test_await_10(self):
999        async def baz():
1000            return 42
1001
1002        async def bar():
1003            return baz()
1004
1005        async def foo():
1006            return await (await bar())
1007
1008        self.assertEqual(run_async(foo()), ([], 42))
1009
1010    def test_await_11(self):
1011        def ident(val):
1012            return val
1013
1014        async def bar():
1015            return 'spam'
1016
1017        async def foo():
1018            return ident(val=await bar())
1019
1020        async def foo2():
1021            return await bar(), 'ham'
1022
1023        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
1024
1025    def test_await_12(self):
1026        async def coro():
1027            return 'spam'
1028
1029        class Awaitable:
1030            def __await__(self):
1031                return coro()
1032
1033        async def foo():
1034            return await Awaitable()
1035
1036        with self.assertRaisesRegex(
1037            TypeError, r"__await__\(\) returned a coroutine"):
1038
1039            run_async(foo())
1040
1041    def test_await_13(self):
1042        class Awaitable:
1043            def __await__(self):
1044                return self
1045
1046        async def foo():
1047            return await Awaitable()
1048
1049        with self.assertRaisesRegex(
1050            TypeError, "__await__.*returned non-iterator of type"):
1051
1052            run_async(foo())
1053
1054    def test_await_14(self):
1055        class Wrapper:
1056            # Forces the interpreter to use CoroutineType.__await__
1057            def __init__(self, coro):
1058                assert coro.__class__ is types.CoroutineType
1059                self.coro = coro
1060            def __await__(self):
1061                return self.coro.__await__()
1062
1063        class FutureLike:
1064            def __await__(self):
1065                return (yield)
1066
1067        class Marker(Exception):
1068            pass
1069
1070        async def coro1():
1071            try:
1072                return await FutureLike()
1073            except ZeroDivisionError:
1074                raise Marker
1075        async def coro2():
1076            return await Wrapper(coro1())
1077
1078        c = coro2()
1079        c.send(None)
1080        with self.assertRaisesRegex(StopIteration, 'spam'):
1081            c.send('spam')
1082
1083        c = coro2()
1084        c.send(None)
1085        with self.assertRaises(Marker):
1086            c.throw(ZeroDivisionError)
1087
1088    def test_await_15(self):
1089        @types.coroutine
1090        def nop():
1091            yield
1092
1093        async def coroutine():
1094            await nop()
1095
1096        async def waiter(coro):
1097            await coro
1098
1099        coro = coroutine()
1100        coro.send(None)
1101
1102        with self.assertRaisesRegex(RuntimeError,
1103                                    "coroutine is being awaited already"):
1104            waiter(coro).send(None)
1105
    def test_with_1(self):
        # Two managers in a single ``async with``: they are entered in order
        # and exited in reverse order, and each hook awaits a two-item
        # AsyncYieldFrom sequence whose items show up in run_async's buffer.
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                # Only the manager named 'B' suppresses the body's exception.
                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        # B (the inner manager) swallows the ZeroDivisionError, so run_async
        # returns normally.
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        # Neither 'A' nor 'C' suppresses, so the ZeroDivisionError escapes.
        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
1145
    def test_with_2(self):
        # An object defining __aenter__ but no __aexit__ cannot be used with
        # ``async with``; the AttributeError must name __aexit__.
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
1157
    def test_with_3(self):
        # An object defining __aexit__ but no __aenter__ cannot be used with
        # ``async with``; the AttributeError must name __aenter__.
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())
1169
    def test_with_4(self):
        # A synchronous-only context manager (__enter__/__exit__) is not
        # accepted by ``async with``; the AttributeError must name __aexit__.
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
1184
1185    def test_with_5(self):
1186        # While this test doesn't make a lot of sense,
1187        # it's a regression test for an early bug with opcodes
1188        # generation
1189
1190        class CM:
1191            async def __aenter__(self):
1192                return self
1193
1194            async def __aexit__(self, *exc):
1195                pass
1196
1197        async def func():
1198            async with CM():
1199                assert (1, ) == 1
1200
1201        with self.assertRaises(AssertionError):
1202            run_async(func())
1203
1204    def test_with_6(self):
1205        class CM:
1206            def __aenter__(self):
1207                return 123
1208
1209            def __aexit__(self, *e):
1210                return 456
1211
1212        async def foo():
1213            async with CM():
1214                pass
1215
1216        with self.assertRaisesRegex(
1217            TypeError, "object int can't be used in 'await' expression"):
1218            # it's important that __aexit__ wasn't called
1219            run_async(foo())
1220
1221    def test_with_7(self):
1222        class CM:
1223            async def __aenter__(self):
1224                return self
1225
1226            def __aexit__(self, *e):
1227                return 444
1228
1229        async def foo():
1230            async with CM():
1231                1/0
1232
1233        try:
1234            run_async(foo())
1235        except TypeError as exc:
1236            self.assertRegex(
1237                exc.args[0], "object int can't be used in 'await' expression")
1238            self.assertTrue(exc.__context__ is not None)
1239            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
1240        else:
1241            self.fail('invalid asynchronous context manager did not fail')
1242
1243
1244    def test_with_8(self):
1245        CNT = 0
1246
1247        class CM:
1248            async def __aenter__(self):
1249                return self
1250
1251            def __aexit__(self, *e):
1252                return 456
1253
1254        async def foo():
1255            nonlocal CNT
1256            async with CM():
1257                CNT += 1
1258
1259
1260        with self.assertRaisesRegex(
1261            TypeError, "object int can't be used in 'await' expression"):
1262
1263            run_async(foo())
1264
1265        self.assertEqual(CNT, 1)
1266
1267
1268    def test_with_9(self):
1269        CNT = 0
1270
1271        class CM:
1272            async def __aenter__(self):
1273                return self
1274
1275            async def __aexit__(self, *e):
1276                1/0
1277
1278        async def foo():
1279            nonlocal CNT
1280            async with CM():
1281                CNT += 1
1282
1283        with self.assertRaises(ZeroDivisionError):
1284            run_async(foo())
1285
1286        self.assertEqual(CNT, 1)
1287
1288    def test_with_10(self):
1289        CNT = 0
1290
1291        class CM:
1292            async def __aenter__(self):
1293                return self
1294
1295            async def __aexit__(self, *e):
1296                1/0
1297
1298        async def foo():
1299            nonlocal CNT
1300            async with CM():
1301                async with CM():
1302                    raise RuntimeError
1303
1304        try:
1305            run_async(foo())
1306        except ZeroDivisionError as exc:
1307            self.assertTrue(exc.__context__ is not None)
1308            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
1309            self.assertTrue(isinstance(exc.__context__.__context__,
1310                                       RuntimeError))
1311        else:
1312            self.fail('exception from __aexit__ did not propagate')
1313
1314    def test_with_11(self):
1315        CNT = 0
1316
1317        class CM:
1318            async def __aenter__(self):
1319                raise NotImplementedError
1320
1321            async def __aexit__(self, *e):
1322                1/0
1323
1324        async def foo():
1325            nonlocal CNT
1326            async with CM():
1327                raise RuntimeError
1328
1329        try:
1330            run_async(foo())
1331        except NotImplementedError as exc:
1332            self.assertTrue(exc.__context__ is None)
1333        else:
1334            self.fail('exception from __aenter__ did not propagate')
1335
1336    def test_with_12(self):
1337        CNT = 0
1338
1339        class CM:
1340            async def __aenter__(self):
1341                return self
1342
1343            async def __aexit__(self, *e):
1344                return True
1345
1346        async def foo():
1347            nonlocal CNT
1348            async with CM() as cm:
1349                self.assertIs(cm.__class__, CM)
1350                raise RuntimeError
1351
1352        run_async(foo())
1353
1354    def test_with_13(self):
1355        CNT = 0
1356
1357        class CM:
1358            async def __aenter__(self):
1359                1/0
1360
1361            async def __aexit__(self, *e):
1362                return True
1363
1364        async def foo():
1365            nonlocal CNT
1366            CNT += 1
1367            async with CM():
1368                CNT += 1000
1369            CNT += 10000
1370
1371        with self.assertRaises(ZeroDivisionError):
1372            run_async(foo())
1373        self.assertEqual(CNT, 1)
1374
    def test_for_1(self):
        # ``async for`` over an iterator whose __aiter__ is a coroutine (the
        # deprecated "legacy" protocol, hence the DeprecationWarning checks).
        # __anext__ awaits AsyncYield on every 10th step, so those values
        # are what run_async collects as "yielded".
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i1, i2 in AsyncIter():
                    buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        # ``break`` must skip the loop's ``else`` clause.
        buffer = []
        async def test2():
            nonlocal buffer
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in AsyncIter():
                    buffer.append(i[0])
                    if i[0] == 20:
                        break
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # __aiter__ was called exactly once more by test2 (the counter is
        # cumulative across the three runs).
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        # ``continue`` must not skip the loop's ``else`` clause.
        buffer = []
        async def test3():
            nonlocal buffer
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in AsyncIter():
                    if i[0] > 20:
                        continue
                    buffer.append(i[0])
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # __aiter__ was called exactly once more by test3 (the counter is
        # cumulative across the three runs).
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])
1449
    def test_for_2(self):
        # ``async for`` over a plain tuple (no __aiter__) raises TypeError,
        # and the failed attempt must not leak a reference to the tuple.
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(tup), refs_before)
1464
1465    def test_for_3(self):
1466        class I:
1467            def __aiter__(self):
1468                return self
1469
1470        aiter = I()
1471        refs_before = sys.getrefcount(aiter)
1472
1473        async def foo():
1474            async for i in aiter:
1475                print('never going to happen')
1476
1477        with self.assertRaisesRegex(
1478                TypeError,
1479                r"async for' received an invalid object.*__aiter.*\: I"):
1480
1481            run_async(foo())
1482
1483        self.assertEqual(sys.getrefcount(aiter), refs_before)
1484
1485    def test_for_4(self):
1486        class I:
1487            def __aiter__(self):
1488                return self
1489
1490            def __anext__(self):
1491                return ()
1492
1493        aiter = I()
1494        refs_before = sys.getrefcount(aiter)
1495
1496        async def foo():
1497            async for i in aiter:
1498                print('never going to happen')
1499
1500        with self.assertRaisesRegex(
1501                TypeError,
1502                "async for' received an invalid object.*__anext__.*tuple"):
1503
1504            run_async(foo())
1505
1506        self.assertEqual(sys.getrefcount(aiter), refs_before)
1507
    def test_for_5(self):
        # __aiter__ is a coroutine (legacy protocol, warns with "legacy") and
        # __anext__ returns a plain int, which is not awaitable: TypeError.
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in I():
                    print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext.*int"):

            run_async(foo())
1526
    def test_for_6(self):
        # ``async for`` nested inside ``async with``. Each digit group of I
        # records one kind of event:
        #   __aenter__ +10000, __aexit__ +100000, each loop iteration +1,
        #   statement after the ``async with`` +1000, loop ``else`` +10000000.
        # Iterable.__anext__ produces 1..11 (it stops once i > 10), so one
        # full loop adds 11.
        I = 0

        class Manager:
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable:
            def __init__(self):
                self.i = 0

            def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        manager = Manager()
        iterable = Iterable()
        mrefs_before = sys.getrefcount(manager)
        irefs_before = sys.getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        with warnings.catch_warnings():
            warnings.simplefilter("error")
            # Test that __aiter__ that returns an asynchronous iterator
            # directly does not throw any warnings.
            run_async(main())
        # 10000 + 11 + 100000 + 1000
        self.assertEqual(I, 111011)

        # Neither object may have leaked a reference.
        self.assertEqual(sys.getrefcount(manager), mrefs_before)
        self.assertEqual(sys.getrefcount(iterable), irefs_before)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        # Previous 111011 plus two more identical rounds.
        self.assertEqual(I, 333033)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        # Previous 333033 plus two rounds of 10111111 (loop ``else`` ran).
        self.assertEqual(I, 20555255)
1618
    def test_for_7(self):
        # The legacy (coroutine) __aiter__ raises ZeroDivisionError. The
        # exception propagates, and neither the loop body nor the statement
        # after the loop runs.
        CNT = 0
        class AI:
            async def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in AI():
                    CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)
1633
    def test_for_8(self):
        # A regular (non-coroutine) __aiter__ raises ZeroDivisionError. The
        # exception propagates, and neither the loop body nor the statement
        # after the loop runs.
        CNT = 0
        class AI:
            def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                # Test that if __aiter__ raises an exception it propagates
                # without any kind of warning.
                run_async(foo())
        self.assertEqual(CNT, 0)
1651
    def test_for_9(self):
        # Test that DeprecationWarning can safely be converted into
        # an exception (__aiter__ should not have a chance to raise
        # a ZeroDivisionError.)
        class AI:
            async def __aiter__(self):
                1/0
        async def foo():
            async for i in AI():
                pass

        # With warnings escalated to errors, the DeprecationWarning itself
        # is what propagates.
        with self.assertRaises(DeprecationWarning):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                run_async(foo())
1667
    def test_for_10(self):
        # Test that DeprecationWarning can safely be converted into
        # an exception.
        # Here the legacy __aiter__ returns None rather than raising.
        class AI:
            async def __aiter__(self):
                pass
        async def foo():
            async for i in AI():
                pass

        with self.assertRaises(DeprecationWarning):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                run_async(foo())
1682
1683    def test_for_11(self):
1684        class F:
1685            def __aiter__(self):
1686                return self
1687            def __anext__(self):
1688                return self
1689            def __await__(self):
1690                1 / 0
1691
1692        async def main():
1693            async for _ in F():
1694                pass
1695
1696        with self.assertRaisesRegex(TypeError,
1697                                    'an invalid object from __anext__') as c:
1698            main().send(None)
1699
1700        err = c.exception
1701        self.assertIsInstance(err.__cause__, ZeroDivisionError)
1702
    def test_for_12(self):
        # __aiter__ returns an object that has no __anext__ and whose
        # __await__ raises. The resulting TypeError must name __aiter__ and
        # carry that ZeroDivisionError as its __cause__.
        class F:
            def __aiter__(self):
                return self
            def __await__(self):
                1 / 0

        async def main():
            async for _ in F():
                pass

        with self.assertRaisesRegex(TypeError,
                                    'an invalid object from __aiter__') as c:
            main().send(None)

        err = c.exception
        self.assertIsInstance(err.__cause__, ZeroDivisionError)
1720
1721    def test_for_tuple(self):
1722        class Done(Exception): pass
1723
1724        class AIter(tuple):
1725            i = 0
1726            def __aiter__(self):
1727                return self
1728            async def __anext__(self):
1729                if self.i >= len(self):
1730                    raise StopAsyncIteration
1731                self.i += 1
1732                return self[self.i - 1]
1733
1734        result = []
1735        async def foo():
1736            async for i in AIter([42]):
1737                result.append(i)
1738            raise Done
1739
1740        with self.assertRaises(Done):
1741            foo().send(None)
1742        self.assertEqual(result, [42])
1743
1744    def test_for_stop_iteration(self):
1745        class Done(Exception): pass
1746
1747        class AIter(StopIteration):
1748            i = 0
1749            def __aiter__(self):
1750                return self
1751            async def __anext__(self):
1752                if self.i:
1753                    raise StopAsyncIteration
1754                self.i += 1
1755                return self.value
1756
1757        result = []
1758        async def foo():
1759            async for i in AIter(42):
1760                result.append(i)
1761            raise Done
1762
1763        with self.assertRaises(Done):
1764            foo().send(None)
1765        self.assertEqual(result, [42])
1766
1767    def test_comp_1(self):
1768        async def f(i):
1769            return i
1770
1771        async def run_list():
1772            return [await c for c in [f(1), f(41)]]
1773
1774        async def run_set():
1775            return {await c for c in [f(1), f(41)]}
1776
1777        async def run_dict1():
1778            return {await c: 'a' for c in [f(1), f(41)]}
1779
1780        async def run_dict2():
1781            return {i: await c for i, c in enumerate([f(1), f(41)])}
1782
1783        self.assertEqual(run_async(run_list()), ([], [1, 41]))
1784        self.assertEqual(run_async(run_set()), ([], {1, 41}))
1785        self.assertEqual(run_async(run_dict1()), ([], {1: 'a', 41: 'a'}))
1786        self.assertEqual(run_async(run_dict2()), ([], {0: 1, 1: 41}))
1787
    def test_comp_2(self):
        # ``await`` in the iterable of an inner ``for`` clause of nested
        # comprehensions, including coroutines whose results are themselves
        # lists of coroutines.
        async def f(i):
            return i

        async def run_list():
            return [s for c in [f(''), f('abc'), f(''), f(['de', 'fg'])]
                    for s in await c]

        self.assertEqual(
            run_async(run_list()),
            ([], ['a', 'b', 'c', 'de', 'fg']))

        async def run_set():
            return {d
                    for c in [f([f([10, 30]),
                                 f([20])])]
                    for s in await c
                    for d in await s}

        self.assertEqual(
            run_async(run_set()),
            ([], {10, 20, 30}))

        async def run_set2():
            return {await s
                    for c in [f([f(10), f(20)])]
                    for s in await c}

        self.assertEqual(
            run_async(run_set2()),
            ([], {10, 20}))
1819
    def test_comp_3(self):
        # ``async for`` comprehensions of every kind over an async generator,
        # plus an async generator expression consumed by another async
        # comprehension.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20])]
        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_set()),
            ([], {11, 21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_dict()),
            ([], {11: 12, 21: 22}))

        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]))
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [111, 121]))
1849
    def test_comp_4(self):
        # Same as test_comp_3, but every comprehension also has an ``if``
        # filter that drops the first item.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20]) if i > 10]
        self.assertEqual(
            run_async(run_list()),
            ([], [21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_set()),
            ([], {21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_dict()),
            ([], {21: 22}))

        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]) if i > 10)
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [121]))
1879
    def test_comp_5(self):
        # A synchronous ``for`` with ``if``, followed by an ``async for`` with
        # its own ``if``, within a single comprehension.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 for pair in ([10, 20], [30, 40]) if pair[0] > 10
                    async for i in f(pair) if i > 30]
        self.assertEqual(
            run_async(run_list()),
            ([], [41]))
1891
    def test_comp_6(self):
        # An outer ``async for`` clause followed by an inner synchronous
        # ``for`` clause.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for seq in f([(10, 20), (30,)])
                    for i in seq]

        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21, 31]))
1904
    def test_comp_7(self):
        # An exception raised by the async generator after some items have
        # been produced propagates out of the async comprehension.
        async def f():
            yield 1
            yield 2
            raise Exception('aaa')

        async def run_list():
            return [i async for i in f()]

        with self.assertRaisesRegex(Exception, 'aaa'):
            run_async(run_list())
1916
    def test_comp_8(self):
        # A plain (synchronous) list comprehension inside an async function.
        async def f():
            return [i for i in [1, 2, 3]]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2, 3]))
1924
    def test_comp_9(self):
        # An async comprehension followed by a synchronous comprehension in
        # the same async function.
        async def gen():
            yield 1
            yield 2
        async def f():
            l = [i async for i in gen()]
            return [i for i in l]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2]))
1936
    def test_comp_10(self):
        # Synchronous set and dict comprehensions inside an async function.
        async def f():
            xx = {i for i in [1, 2, 3]}
            return {x: x for x in xx}

        self.assertEqual(
            run_async(f()),
            ([], {1: 1, 2: 2, 3: 3}))
1945
1946    def test_copy(self):
1947        async def func(): pass
1948        coro = func()
1949        with self.assertRaises(TypeError):
1950            copy.copy(coro)
1951
1952        aw = coro.__await__()
1953        try:
1954            with self.assertRaises(TypeError):
1955                copy.copy(aw)
1956        finally:
1957            aw.close()
1958
1959    def test_pickle(self):
1960        async def func(): pass
1961        coro = func()
1962        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1963            with self.assertRaises((TypeError, pickle.PicklingError)):
1964                pickle.dumps(coro, proto)
1965
1966        aw = coro.__await__()
1967        try:
1968            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1969                with self.assertRaises((TypeError, pickle.PicklingError)):
1970                    pickle.dumps(aw, proto)
1971        finally:
1972            aw.close()
1973
    def test_fatal_coro_warning(self):
        # Issue 27811
        # A never-awaited coroutine is collected while warnings are turned
        # into errors. The "was never awaited" report must end up on stderr,
        # and no exception may escape this block.
        async def func(): pass
        with warnings.catch_warnings(), support.captured_stderr() as stderr:
            warnings.filterwarnings("error")
            func()
            support.gc_collect()
        self.assertIn("was never awaited", stderr.getvalue())
1982
1983
class CoroAsyncIOCompatTest(unittest.TestCase):

    def test_asyncio_1(self):
        """An ``async with`` block runs correctly on an asyncio event loop.

        __aexit__ sees the exception raised in the body and does not
        suppress it, so the exception propagates out of
        run_until_complete().
        """
        # asyncio cannot be imported when Python is compiled without thread
        # support
        asyncio = support.import_module('asyncio')

        class MyException(Exception):
            pass

        buffer = []

        class CM:
            async def __aenter__(self):
                buffer.append(1)
                await asyncio.sleep(0.01)
                buffer.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                buffer.append(exc_type.__name__)

        async def f():
            async with CM() as c:
                await asyncio.sleep(0.01)
                raise MyException
            buffer.append('unreachable')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # __aexit__ returns None, so MyException must escape the loop.
            with self.assertRaises(MyException):
                loop.run_until_complete(f())
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(buffer, [1, 2, 'MyException'])
2024
2025
class SysSetCoroWrapperTest(unittest.TestCase):
    # Tests for sys.set_coroutine_wrapper() / sys.get_coroutine_wrapper().
    # Tests that install a wrapper reset it to None in a ``finally`` block
    # so that a failure cannot leak a global wrapper into other tests.

    def test_set_wrapper_1(self):
        # The installed wrapper is called for each new native coroutine and
        # must stop being called once it is removed.
        async def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        self.assertIsNone(sys.get_coroutine_wrapper())

        sys.set_coroutine_wrapper(wrap)
        self.assertIs(sys.get_coroutine_wrapper(), wrap)
        try:
            f = foo()
            self.assertTrue(wrapped)

            self.assertEqual(run_async(f), ([], 'spam'))
        finally:
            sys.set_coroutine_wrapper(None)

        self.assertIsNone(sys.get_coroutine_wrapper())

        wrapped = None
        with silence_coro_gc():
            foo()
        self.assertFalse(wrapped)

    def test_set_wrapper_2(self):
        # A non-callable wrapper is rejected and leaves the state unchanged.
        self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            sys.set_coroutine_wrapper(1)
        self.assertIsNone(sys.get_coroutine_wrapper())

    def test_set_wrapper_3(self):
        # A wrapper that itself creates a coroutine would wrap recursively;
        # this must be reported as a RuntimeError.
        async def foo():
            return 'spam'

        def wrapper(coro):
            async def wrap(coro):
                return await coro
            return wrap(coro)

        sys.set_coroutine_wrapper(wrapper)
        try:
            with silence_coro_gc(), self.assertRaisesRegex(
                RuntimeError,
                r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
                r"recursively wrap .* wrap .*"):

                foo()
        finally:
            sys.set_coroutine_wrapper(None)

    def test_set_wrapper_4(self):
        # Calling a @types.coroutine-decorated function (here not even a
        # generator function) must not invoke the coroutine wrapper.
        @types.coroutine
        def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        sys.set_coroutine_wrapper(wrap)
        try:
            foo()
            self.assertIs(
                wrapped, None,
                "generator-based coroutine was wrapped via "
                "sys.set_coroutine_wrapper")
        finally:
            sys.set_coroutine_wrapper(None)
2103
2104
class CAPITest(unittest.TestCase):
    # Tests for the C-level tp_as_async await slot, using the awaitType
    # helper from _testcapi. Assumption to confirm against the _testcapi
    # sources: awaitType(obj) hands back ``obj`` from its await slot.

    def test_tp_await_1(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        # The wrapped iterator's single item is yielded to the caller.
        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        # The await fails before anything is yielded, so there is no result
        # to compare; just drive the coroutine and expect the TypeError.
        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            foo().send(None)
2132
2133
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
2136