diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 0faad40c2d7f52..4f2278bb263681 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -1798,12 +1798,6 @@ async def run(): res = self.loop.run_until_complete(run()) self.assertEqual(res, [i * 2 for i in range(1, 10)]) - def test_async_gen_expression_incorrect(self): - err_msg = "'async for' requires an object with " \ - "__aiter__ method, got int" - with self.assertRaisesRegex(TypeError, err_msg): - g = (x async for x in 42) - def test_asyncgen_nonstarted_hooks_are_cancellable(self): # See https://bugs.python.org/issue38013 messages = [] diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index bf2cb1160723b0..1bcf3fc2750d32 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -341,6 +341,104 @@ def get_generator_genfunc(obj): self.process_tests(get_generator_genfunc) +class SequenceClass: + def __init__(self, n): + self.n = n + def __getitem__(self, i): + if 0 <= i < self.n: + return i + else: + raise IndexError + + +class IncorrectIterable: + def __iter__(self): + return 123 + + +class IncorrectAsyncIterable: + def __aiter__(self): + return 123 + + +class CheckIterableTest(unittest.TestCase): + sequences = ( + (1, 2), + [1, 2, 3], + range(42), + SequenceClass(10), + ) + + non_iterables = ( + None, + 42, + 13.0, + ) + + err_msg_sync = "'.*' object is not iterable" + err_msg_async = "'async for' requires an object with " \ + "__aiter__ method, got .*" + + def test_sequences(self): + for seq in self.sequences: + (x for x in seq) + (x for x in iter(seq)) + with self.assertRaisesRegex(TypeError, self.err_msg_async): + (x async for x in seq) + with self.assertRaisesRegex(TypeError, self.err_msg_async): + (x async for x in iter(seq)) + + def test_non_iterables(self): + for obj in self.non_iterables: + with self.assertRaisesRegex(TypeError, self.err_msg_sync): + (x for x in obj) + with self.assertRaisesRegex(TypeError, self.err_msg_async): + (x async for x in obj) + + def test_generators(self): + def gen(): + yield 1 + + (x for x in gen()) + (x for x in iter(gen())) + + with self.assertRaisesRegex(TypeError, self.err_msg_async): + (x async for x in gen()) + with self.assertRaisesRegex(TypeError, self.err_msg_async): + (x async for x in iter(gen())) + + def test_async_generators(self): + async def agen(): + yield 1 + yield 2 + + with self.assertRaisesRegex(TypeError, self.err_msg_sync): + (x for x in agen()) + with self.assertRaisesRegex(TypeError, self.err_msg_sync): + (x for x in aiter(agen())) + + (x async for x in agen()) + (x async for x in aiter(agen())) + + def test_incorrect_iterable(self): + g = (x for x in IncorrectIterable()) + err_msg = ".* returned non-iterator of type '.*'" + with self.assertRaisesRegex(TypeError, err_msg): + list(g) + + def test_incorrect_async_iterable(self): + g = (x async for x in IncorrectAsyncIterable()) + + async def coroutine(): + async for x in g: + pass + + err_msg = "'async for' received an object from __aiter__ " \ + "that does not implement __anext__: .*" + with self.assertRaisesRegex(TypeError, err_msg): + coroutine().send(None) + + class ExceptionTest(unittest.TestCase): # Tests for the issue #23353: check that the currently handled exception # is correctly saved/restored in PyEval_EvalFrameEx().