diff --git a/fastapi_utils/tasks.py b/fastapi_utils/tasks.py index 1ef74e1..5c83f65 100644 --- a/fastapi_utils/tasks.py +++ b/fastapi_utils/tasks.py @@ -2,6 +2,7 @@ import asyncio import logging +import warnings from functools import wraps from traceback import format_exception from typing import Any, Callable, Coroutine, Union @@ -10,7 +11,26 @@ NoArgsNoReturnFuncT = Callable[[], None] NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]] -NoArgsNoReturnDecorator = Callable[[Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]], NoArgsNoReturnAsyncFuncT] +ExcArgNoReturnFuncT = Callable[[Exception], None] +ExcArgNoReturnAsyncFuncT = Callable[[Exception], Coroutine[Any, Any, None]] +NoArgsNoReturnAnyFuncT = Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT] +ExcArgNoReturnAnyFuncT = Union[ExcArgNoReturnFuncT, ExcArgNoReturnAsyncFuncT] +NoArgsNoReturnDecorator = Callable[[NoArgsNoReturnAnyFuncT], NoArgsNoReturnAsyncFuncT] + + +async def _handle_func(func: NoArgsNoReturnAnyFuncT) -> None: + if asyncio.iscoroutinefunction(func): + await func() + else: + await run_in_threadpool(func) + + +async def _handle_exc(exc: Exception, on_exception: ExcArgNoReturnAnyFuncT | None) -> None: + if on_exception: + if asyncio.iscoroutinefunction(on_exception): + await on_exception(exc) + else: + await run_in_threadpool(on_exception, exc) def repeat_every( @@ -20,6 +40,8 @@ def repeat_every( logger: logging.Logger | None = None, raise_exceptions: bool = False, max_repetitions: int | None = None, + on_complete: NoArgsNoReturnAnyFuncT | None = None, + on_exception: ExcArgNoReturnAnyFuncT | None = None, ) -> NoArgsNoReturnDecorator: """ This function returns a decorator that modifies a function so it is periodically re-executed after its first call. @@ -34,47 +56,62 @@ def repeat_every( wait_first: float (default None) If not None, the function will wait for the given duration before the first call logger: Optional[logging.Logger] (default None) + Warning: This parameter is deprecated and will be removed in the 1.0 release. The logger to use to log any exceptions raised by calls to the decorated function. If not provided, exceptions will not be logged by this function (though they may be handled by the event loop). raise_exceptions: bool (default False) + Warning: This parameter is deprecated and will be removed in the 1.0 release. If True, errors raised by the decorated function will be raised to the event loop's exception handler. Note that if an error is raised, the repeated execution will stop. Otherwise, exceptions are just logged and the execution continues to repeat. See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.set_exception_handler for more info. max_repetitions: Optional[int] (default None) The maximum number of times to call the repeated function. If `None`, the function is repeated forever. + on_complete: Optional[Callable[[], None]] (default None) + A function to call after the final repetition of the decorated function. + on_exception: Optional[Callable[[Exception], None]] (default None) + A function to call when an exception is raised by the decorated function. """ - def decorator(func: NoArgsNoReturnAsyncFuncT | NoArgsNoReturnFuncT) -> NoArgsNoReturnAsyncFuncT: + def decorator(func: NoArgsNoReturnAnyFuncT) -> NoArgsNoReturnAsyncFuncT: """ Converts the decorated function into a repeated, periodically-called version of itself. """ - is_coroutine = asyncio.iscoroutinefunction(func) @wraps(func) async def wrapped() -> None: - repetitions = 0 - async def loop() -> None: - nonlocal repetitions if wait_first is not None: await asyncio.sleep(wait_first) + + repetitions = 0 while max_repetitions is None or repetitions < max_repetitions: try: - if is_coroutine: - await func() # type: ignore - else: - await run_in_threadpool(func) + await _handle_func(func) + except Exception as exc: if logger is not None: + warnings.warn( + "'logger' is to be deprecated in favor of 'on_exception' in the 1.0 release.", + DeprecationWarning, + ) formatted_exception = "".join(format_exception(type(exc), exc, exc.__traceback__)) logger.error(formatted_exception) if raise_exceptions: + warnings.warn( + "'raise_exceptions' is to be deprecated in favor of 'on_exception' in the 1.0 release.", + DeprecationWarning, + ) raise exc + await _handle_exc(exc, on_exception) + repetitions += 1 await asyncio.sleep(seconds) - await loop() + if on_complete: + await _handle_func(on_complete) + + asyncio.ensure_future(loop()) return wrapped diff --git a/poetry.lock b/poetry.lock index 052f1f9..9922307 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -656,6 +656,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1192,6 +1202,20 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "pytest-timeout" +version = "2.3.1" +description = "pytest plugin to abort hanging tests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"}, + {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1256,6 +1280,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1529,7 +1554,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} typing-extensions = ">=4.2.0" @@ -1778,4 +1803,4 @@ session = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "33974ee1d8387bfc84980abc4c04608f33056ed76c6f07c2d3087095f8f089c5" +content-hash = "bbd87e057fea93e8060587a5d5a0f239b31f58c64eacf7d17b88c75e76f70ea4" diff --git a/pyproject.toml b/pyproject.toml index e9f1989..4d2fd09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ session = ["sqlalchemy"] [tool.poetry.group.dev.dependencies] codecov = "^2.1.13" +pytest-timeout = "^2.3.1" [tool.black] line-length = 120 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 11310b7..2218232 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,3 +1,4 @@ +import asyncio import sys from typing import TYPE_CHECKING, NoReturn @@ -37,42 +38,108 @@ def wait_first(seconds: float) -> float: class TestRepeatEveryBase: def setup_method(self) -> None: self.counter = 0 + self.completed = asyncio.Event() def increase_counter(self) -> None: self.counter += 1 + async def increase_counter_async(self) -> None: + self.increase_counter() + + def loop_completed(self) -> None: + self.completed.set() + + async def loop_completed_async(self) -> None: + self.loop_completed() + + def kill_loop(self, exc: Exception) -> None: + self.completed.set() + raise exc + + async def kill_loop_async(self, exc: Exception) -> None: + self.kill_loop(exc) + + def continue_loop(self, exc: Exception) -> None: + return + + async def continue_loop_async(self, exc: Exception) -> None: + self.continue_loop(exc) + + def raise_exc(self) -> NoReturn: + self.increase_counter() + raise ValueError("error") + + async def raise_exc_async(self) -> NoReturn: + self.raise_exc() -class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase): @pytest.fixture - def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter) + def increase_counter_task(self, is_async: bool, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: + decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, on_complete=self.loop_completed) + if is_async: + return decorator(self.increase_counter_async) + else: + return decorator(self.increase_counter) @pytest.fixture def wait_first_increase_counter_task( - self, seconds: float, max_repetitions: int, wait_first: float + self, is_async: bool, seconds: float, max_repetitions: int, wait_first: float ) -> NoArgsNoReturnAsyncFuncT: - decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first) - return decorator(self.increase_counter) + decorator = repeat_every( + seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first, on_complete=self.loop_completed + ) + if is_async: + return decorator(self.increase_counter_async) + else: + return decorator(self.increase_counter) - @staticmethod @pytest.fixture - def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - @repeat_every(seconds=seconds, max_repetitions=max_repetitions) - def raise_exc() -> NoReturn: - raise ValueError("error") + def stop_on_exception_task(self, is_async: bool, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: + if is_async: + decorator = repeat_every( + seconds=seconds, + max_repetitions=max_repetitions, + on_complete=self.loop_completed_async, + on_exception=self.kill_loop_async, + ) + return decorator(self.raise_exc_async) + else: + decorator = repeat_every( + seconds=seconds, + max_repetitions=max_repetitions, + on_complete=self.loop_completed, + on_exception=self.kill_loop, + ) + return decorator(self.raise_exc) - return raise_exc - - @staticmethod @pytest.fixture - def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - @repeat_every(seconds=seconds, raise_exceptions=True) - def raise_exc() -> NoReturn: - raise ValueError("error") + def suppressed_exception_task( + self, is_async: bool, seconds: float, max_repetitions: int + ) -> NoArgsNoReturnAsyncFuncT: + if is_async: + decorator = repeat_every( + seconds=seconds, + max_repetitions=max_repetitions, + on_complete=self.loop_completed_async, + on_exception=self.continue_loop_async, + ) + return decorator(self.raise_exc_async) + else: + decorator = repeat_every( + seconds=seconds, + max_repetitions=max_repetitions, + on_complete=self.loop_completed, + on_exception=self.continue_loop, + ) + return decorator(self.raise_exc) - return raise_exc + +class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase): + @pytest.fixture + def is_async(self) -> bool: + return False @pytest.mark.asyncio + @pytest.mark.timeout(1) @patch("asyncio.sleep") async def test_max_repetitions( self, @@ -82,73 +149,62 @@ async def test_max_repetitions( increase_counter_task: NoArgsNoReturnAsyncFuncT, ) -> None: await increase_counter_task() + await self.completed.wait() assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True) @pytest.mark.asyncio + @pytest.mark.timeout(1) @patch("asyncio.sleep") async def test_max_repetitions_and_wait_first( self, asyncio_sleep_mock: AsyncMock, seconds: float, max_repetitions: int, - wait_first: float, wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT, ) -> None: await wait_first_increase_counter_task() + await self.completed.wait() assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True) @pytest.mark.asyncio - async def test_raise_exceptions_false( - self, seconds: float, max_repetitions: int, raising_task: NoArgsNoReturnAsyncFuncT + @pytest.mark.timeout(1) + async def test_stop_loop_on_exc( + self, + stop_on_exception_task: NoArgsNoReturnAsyncFuncT, ) -> None: - try: - await raising_task() - except ValueError as e: - pytest.fail(f"{self.test_raise_exceptions_false.__name__} raised an exception: {e}") + await stop_on_exception_task() + await self.completed.wait() + + assert self.counter == 1 @pytest.mark.asyncio - async def test_raise_exceptions_true( - self, seconds: float, suppressed_exception_task: NoArgsNoReturnAsyncFuncT + @pytest.mark.timeout(1) + @patch("asyncio.sleep") + async def test_continue_loop_on_exc( + self, + asyncio_sleep_mock: AsyncMock, + seconds: float, + max_repetitions: int, + suppressed_exception_task: NoArgsNoReturnAsyncFuncT, ) -> None: - with pytest.raises(ValueError): - await suppressed_exception_task() - + await suppressed_exception_task() + await self.completed.wait() -class TestRepeatEveryWithAsynchronousFunction(TestRepeatEveryBase): - @pytest.fixture - def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter) - - @pytest.fixture - def wait_first_increase_counter_task( - self, seconds: float, max_repetitions: int, wait_first: float - ) -> NoArgsNoReturnAsyncFuncT: - decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first) - return decorator(self.increase_counter) - - @staticmethod - @pytest.fixture - def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - @repeat_every(seconds=seconds, max_repetitions=max_repetitions) - async def raise_exc() -> NoReturn: - raise ValueError("error") + assert self.counter == max_repetitions + asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True) - return raise_exc - @staticmethod +class TestRepeatEveryWithAsynchronousFunction(TestRepeatEveryBase): @pytest.fixture - def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT: - @repeat_every(seconds=seconds, raise_exceptions=True) - async def raise_exc() -> NoReturn: - raise ValueError("error") - - return raise_exc + def is_async(self) -> bool: + return True @pytest.mark.asyncio + @pytest.mark.timeout(1) @patch("asyncio.sleep") async def test_max_repetitions( self, @@ -158,11 +214,13 @@ async def test_max_repetitions( increase_counter_task: NoArgsNoReturnAsyncFuncT, ) -> None: await increase_counter_task() + await self.completed.wait() assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True) @pytest.mark.asyncio + @pytest.mark.timeout(1) @patch("asyncio.sleep") async def test_max_repetitions_and_wait_first( self, @@ -172,22 +230,34 @@ async def test_max_repetitions_and_wait_first( wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT, ) -> None: await wait_first_increase_counter_task() + await self.completed.wait() assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True) @pytest.mark.asyncio - async def test_raise_exceptions_false( - self, seconds: float, max_repetitions: int, raising_task: NoArgsNoReturnAsyncFuncT + @pytest.mark.timeout(1) + async def test_stop_loop_on_exc( + self, + stop_on_exception_task: NoArgsNoReturnAsyncFuncT, ) -> None: - try: - await raising_task() - except ValueError as e: - pytest.fail(f"{self.test_raise_exceptions_false.__name__} raised an exception: {e}") + await stop_on_exception_task() + await self.completed.wait() + + assert self.counter == 1 @pytest.mark.asyncio - async def test_raise_exceptions_true( - self, seconds: float, suppressed_exception_task: NoArgsNoReturnAsyncFuncT + @pytest.mark.timeout(1) + @patch("asyncio.sleep") + async def test_continue_loop_on_exc( + self, + asyncio_sleep_mock: AsyncMock, + seconds: float, + max_repetitions: int, + suppressed_exception_task: NoArgsNoReturnAsyncFuncT, ) -> None: - with pytest.raises(ValueError): - await suppressed_exception_task() + await suppressed_exception_task() + await self.completed.wait() + + assert self.counter == max_repetitions + asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)