# Copyright 2022 Amethyst Reese # Licensed under the MIT license import asyncio import functools import sys from unittest import skipIf, TestCase from aioitertools.helpers import maybe_await def async_test(fn): def wrapped(*args, **kwargs): try: loop = asyncio.new_event_loop() loop.set_debug(False) result = loop.run_until_complete(fn(*args, **kwargs)) return result finally: loop.close() return wrapped class HelpersTest(TestCase): # aioitertools.helpers.maybe_await() @async_test async def test_maybe_await(self): self.assertEqual(await maybe_await(42), 42) @async_test async def test_maybe_await_async_def(self): async def forty_two(): await asyncio.sleep(0.0001) return 42 self.assertEqual(await maybe_await(forty_two()), 42) @skipIf(sys.version_info >= (3, 11), "@asyncio.coroutine removed") @async_test async def test_maybe_await_coroutine(self): @asyncio.coroutine def forty_two(): yield from asyncio.sleep(0.0001) return 42 self.assertEqual(await maybe_await(forty_two()), 42) @async_test async def test_maybe_await_partial(self): async def multiply(a, b): await asyncio.sleep(0.0001) return a * b self.assertEqual(await maybe_await(functools.partial(multiply, 6)(7)), 42)