diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 6e3783e6..9c54c39f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -61,14 +61,14 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10]
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10]
include:
- os: macos-latest
- python-version: "3.8"
+ python-version: "3.9"
- os: macos-latest
python-version: "3.12"
- os: windows-latest
- python-version: "3.8"
+ python-version: "3.9"
- os: windows-latest
python-version: "3.12"
runs-on: ${{ matrix.os }}
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 81da27ee..ef7ac30c 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -3,6 +3,22 @@ Version history
This library adheres to `Semantic Versioning 2.0 `_.
+**UNRELEASED**
+
+- Dropped support for Python 3.8
+ (as `#698 <https://github.com/agronholm/anyio/issues/698>`_ cannot be resolved
+ without cancel message support)
+- Fixed 100% CPU use on asyncio when waiting for an exiting task group to finish while
+ said task group is within a cancelled cancel scope
+ (`#695 <https://github.com/agronholm/anyio/issues/695>`_)
+- Fixed cancel scopes on asyncio not reraising ``CancelledError`` on exit when the
+ enclosing cancel scope has been effectively cancelled
+ (`#698 <https://github.com/agronholm/anyio/issues/698>`_)
+- Fixed asyncio task groups not yielding control to the event loop at exit if there were
+ no child tasks to wait on
+- Fixed inconsistent task uncancellation with asyncio cancel scopes belonging to a
+ task group when said task group has child tasks running
+
**4.5.0**
- Improved the performance of ``anyio.Lock`` and ``anyio.Semaphore`` on asyncio (even up
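The ``CancelledError`` re-raising fix above is easiest to see with the scope-based
checkpoint idiom: once an enclosing scope has been cancelled, a checkpoint built from a
cancelled inner scope must let the cancellation propagate instead of swallowing it. A
minimal sketch using only public AnyIO APIs, mirroring the regression test added further
down in this patch::

    import anyio
    from anyio import CancelScope, sleep_forever

    async def main() -> None:
        with CancelScope() as outer_scope:
            outer_scope.cancel()

            # A checkpoint implemented with a cancelled inner scope (see
            # python-trio/trio#860); with the fix, the CancelledError raised
            # here belongs to the cancelled outer scope and is not swallowed
            # by the inner scope.
            with CancelScope() as inner_scope:
                inner_scope.cancel()
                await sleep_forever()

            raise AssertionError("the checkpoint should have raised")

        assert not inner_scope.cancelled_caught
        assert outer_scope.cancelled_caught

    anyio.run(main)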
diff --git a/pyproject.toml b/pyproject.toml
index 3bea40c1..4e726f4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,14 +19,13 @@ classifiers = [
"Typing :: Typed",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
-requires-python = ">= 3.8"
+requires-python = ">= 3.9"
dependencies = [
"exceptiongroup >= 1.0.2; python_version < '3.11'",
"idna >= 2.8",
@@ -128,7 +127,7 @@ show_missing = true
[tool.tox]
legacy_tox_ini = """
[tox]
-envlist = pre-commit, py38, py39, py310, py311, py312, py313, pypy3
+envlist = pre-commit, py39, py310, py311, py312, py313, pypy3
skip_missing_interpreters = true
minversion = 4.0.0
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 0d4cdf65..9342fab8 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -20,9 +20,18 @@
)
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import (
+ AsyncGenerator,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Collection,
+ Coroutine,
+ Iterable,
+ Sequence,
+)
from concurrent.futures import Future
-from contextlib import suppress
+from contextlib import AbstractContextManager, suppress
from contextvars import Context, copy_context
from dataclasses import dataclass
from functools import partial, wraps
@@ -42,15 +51,7 @@
from typing import (
IO,
Any,
- AsyncGenerator,
- Awaitable,
- Callable,
- Collection,
- ContextManager,
- Coroutine,
Optional,
- Sequence,
- Tuple,
TypeVar,
cast,
)
@@ -358,6 +359,14 @@ def _task_started(task: asyncio.Task) -> bool:
#
+def is_anyio_cancellation(exc: CancelledError) -> bool:
+ return (
+ bool(exc.args)
+ and isinstance(exc.args[0], str)
+ and exc.args[0].startswith("Cancelled by cancel scope ")
+ )
+
+
class CancelScope(BaseCancelScope):
def __new__(
cls, *, deadline: float = math.inf, shield: bool = False
@@ -444,35 +453,77 @@ def __exit__(
host_task_state.cancel_scope = self._parent_scope
- # Restart the cancellation effort in the closest directly cancelled parent
- # scope if this one was shielded
- self._restart_cancellation_in_parent()
+ # Undo all cancellations done by this scope
+ if self._cancelling is not None:
+ while self._cancel_calls:
+ self._cancel_calls -= 1
+ if self._host_task.uncancel() <= self._cancelling:
+ break
- if self._cancel_called and exc_val is not None:
+ # We only swallow the exception iff it was an AnyIO CancelledError, either
+ # directly as exc_val or inside an exception group and there are no cancelled
+ # parent cancel scopes visible to us here
+ not_swallowed_exceptions = 0
+ swallow_exception = False
+ if exc_val is not None:
for exc in iterate_exceptions(exc_val):
- if isinstance(exc, CancelledError):
- self._cancelled_caught = self._uncancel(exc)
- if self._cancelled_caught:
- break
+ if self._cancel_called and isinstance(exc, CancelledError):
+ if not (swallow_exception := self._uncancel(exc)):
+ not_swallowed_exceptions += 1
+ else:
+ not_swallowed_exceptions += 1
+
+ # Restart the cancellation effort in the closest visible, cancelled parent
+ # scope if necessary
+ self._restart_cancellation_in_parent()
+ return swallow_exception and not not_swallowed_exceptions
- return self._cancelled_caught
+ @property
+ def _effectively_cancelled(self) -> bool:
+ cancel_scope: CancelScope | None = self
+ while cancel_scope is not None:
+ if cancel_scope._cancel_called:
+ return True
- return None
+ if cancel_scope.shield:
+ return False
+
+ cancel_scope = cancel_scope._parent_scope
+
+ return False
+
+ @property
+ def _parent_cancellation_is_visible_to_us(self) -> bool:
+ return (
+ self._parent_scope is not None
+ and not self.shield
+ and self._parent_scope._effectively_cancelled
+ )
def _uncancel(self, cancelled_exc: CancelledError) -> bool:
- if sys.version_info < (3, 9) or self._host_task is None:
+ if self._host_task is None:
self._cancel_calls = 0
return True
- # Undo all cancellations done by this scope
- if self._cancelling is not None:
- while self._cancel_calls:
- self._cancel_calls -= 1
- if self._host_task.uncancel() <= self._cancelling:
- return True
+ while True:
+ if is_anyio_cancellation(cancelled_exc):
+ # Only swallow the cancellation exception if it's an AnyIO cancel
+ # exception and there are no other cancel scopes down the line pending
+ # cancellation
+ self._cancelled_caught = (
+ self._effectively_cancelled
+ and not self._parent_cancellation_is_visible_to_us
+ )
+ return self._cancelled_caught
- self._cancel_calls = 0
- return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args
+ # Sometimes third party frameworks catch a CancelledError and raise a new
+ # one, so as a workaround we have to look at the previous ones in
+ # __context__ too for a matching cancel message
+ if isinstance(cancelled_exc.__context__, CancelledError):
+ cancelled_exc = cancelled_exc.__context__
+ continue
+
+ return False
def _timeout(self) -> None:
if self._deadline != math.inf:
@@ -496,19 +547,17 @@ def _deliver_cancellation(self, origin: CancelScope) -> bool:
should_retry = False
current = current_task()
for task in self._tasks:
+ should_retry = True
if task._must_cancel: # type: ignore[attr-defined]
continue
# The task is eligible for cancellation if it has started
- should_retry = True
if task is not current and (task is self._host_task or _task_started(task)):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
- origin._cancel_calls += 1
- if sys.version_info >= (3, 9):
- task.cancel(f"Cancelled by cancel scope {id(origin):x}")
- else:
- task.cancel()
+ task.cancel(f"Cancelled by cancel scope {id(origin):x}")
+ if task is origin._host_task:
+ origin._cancel_calls += 1
# Deliver cancellation to child scopes that aren't shielded or running their own
# cancellation callbacks
@@ -546,17 +595,6 @@ def _restart_cancellation_in_parent(self) -> None:
scope = scope._parent_scope
- def _parent_cancelled(self) -> bool:
- # Check whether any parent has been cancelled
- cancel_scope = self._parent_scope
- while cancel_scope is not None and not cancel_scope._shield:
- if cancel_scope._cancel_called:
- return True
- else:
- cancel_scope = cancel_scope._parent_scope
-
- return False
-
def cancel(self) -> None:
if not self._cancel_called:
if self._timeout_handle:
@@ -663,38 +701,50 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
- ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)
- cancelled_exc_while_waiting_tasks: CancelledError | None = None
- while self._tasks:
- try:
- await asyncio.wait(self._tasks)
- except CancelledError as exc:
- # This task was cancelled natively; reraise the CancelledError later
- # unless this task was already interrupted by another exception
- self.cancel_scope.cancel()
- if cancelled_exc_while_waiting_tasks is None:
- cancelled_exc_while_waiting_tasks = exc
+ try:
+ if self._tasks:
+ with CancelScope() as wait_scope:
+ while self._tasks:
+ try:
+ await asyncio.wait(self._tasks)
+ except CancelledError as exc:
+ # Shield the scope against further cancellation attempts,
+ # as they're not productive (#695)
+ wait_scope.shield = True
+ self.cancel_scope.cancel()
+
+ # Set exc_val from the cancellation exception if it was
+ # previously unset. However, we should not replace a native
+ # cancellation exception with one raised by a cancel scope.
+ if exc_val is None or (
+ isinstance(exc_val, CancelledError)
+ and not is_anyio_cancellation(exc)
+ ):
+ exc_val = exc
+ else:
+ # If there are no child tasks to wait on, run at least one checkpoint
+ # anyway
+ await AsyncIOBackend.cancel_shielded_checkpoint()
- self._active = False
- if self._exceptions:
- raise BaseExceptionGroup(
- "unhandled errors in a TaskGroup", self._exceptions
- )
+ self._active = False
+ if self._exceptions:
+ raise BaseExceptionGroup(
+ "unhandled errors in a TaskGroup", self._exceptions
+ )
+ elif exc_val:
+ raise exc_val
+ except BaseException as exc:
+ if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
+ return True
- # Raise the CancelledError received while waiting for child tasks to exit,
- # unless the context manager itself was previously exited with another
- # exception, or if any of the child tasks raised an exception other than
- # CancelledError
- if cancelled_exc_while_waiting_tasks:
- if exc_val is None or ignore_exception:
- raise cancelled_exc_while_waiting_tasks
+ raise
- return ignore_exception
+ return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
def _spawn(
self,
@@ -730,7 +780,7 @@ def task_done(_task: asyncio.Task) -> None:
if not isinstance(exc, CancelledError):
self._exceptions.append(exc)
- if not self.cancel_scope._parent_cancelled():
+ if not self.cancel_scope._effectively_cancelled:
self.cancel_scope.cancel()
else:
task_status_future.set_exception(exc)
@@ -806,7 +856,7 @@ async def start(
# Threads
#
-_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
+_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
class WorkerThread(Thread):
@@ -955,7 +1005,7 @@ class Process(abc.Process):
_stderr: StreamReaderWrapper | None
async def aclose(self) -> None:
- with CancelScope(shield=True):
+ with CancelScope(shield=True) as scope:
if self._stdin:
await self._stdin.aclose()
if self._stdout:
@@ -963,14 +1013,14 @@ async def aclose(self) -> None:
if self._stderr:
await self._stderr.aclose()
- try:
- await self.wait()
- except BaseException:
- self.kill()
- with CancelScope(shield=True):
+ scope.shield = False
+ try:
await self.wait()
-
- raise
+ except BaseException:
+ scope.shield = True
+ self.kill()
+ await self.wait()
+ raise
async def wait(self) -> int:
return await self._process.wait()
@@ -2015,9 +2065,7 @@ def has_pending_cancellation(self) -> bool:
if task_state := _task_states.get(task):
if cancel_scope := task_state.cancel_scope:
- return cancel_scope.cancel_called or (
- not cancel_scope.shield and cancel_scope._parent_cancelled()
- )
+ return cancel_scope._effectively_cancelled
return False
@@ -2111,7 +2159,7 @@ async def _call_in_runner_task(
) -> T_Retval:
if not self._runner_task:
self._send_stream, receive_stream = create_memory_object_stream[
- Tuple[Awaitable[Any], asyncio.Future]
+ tuple[Awaitable[Any], asyncio.Future]
](1)
self._runner_task = self.get_loop().create_task(
self._run_tests_and_fixtures(receive_stream)
@@ -2473,7 +2521,7 @@ async def connect_tcp(
cls, host: str, port: int, local_address: IPSockAddrType | None = None
) -> abc.SocketStream:
transport, protocol = cast(
- Tuple[asyncio.Transport, StreamProtocol],
+ tuple[asyncio.Transport, StreamProtocol],
await get_running_loop().create_connection(
StreamProtocol, host, port, local_addr=local_address
),
@@ -2652,7 +2700,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter:
@classmethod
def open_signal_receiver(
cls, *signals: Signals
- ) -> ContextManager[AsyncIterator[Signals]]:
+ ) -> AbstractContextManager[AsyncIterator[Signals]]:
return _SignalReceiver(signals)
@classmethod
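The asyncio backend changes above lean on two stock asyncio mechanisms: the message
passed to ``Task.cancel()`` ends up in ``CancelledError.args`` (which is what
``is_anyio_cancellation()`` inspects), and on Python 3.11+ each ``cancel()`` call is
counted and can be undone with ``Task.uncancel()``. A standalone sketch of those
mechanisms, not AnyIO code; the scope id in the message is a made-up placeholder::

    import asyncio
    import sys

    async def main() -> None:
        async def child() -> None:
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError as exc:
                # The cancel message shows up in exc.args, allowing the scope
                # that requested the cancellation to be identified.
                assert exc.args and exc.args[0].startswith("Cancelled by cancel scope ")
                raise

        task = asyncio.create_task(child())
        await asyncio.sleep(0)  # let the child reach its await
        task.cancel("Cancelled by cancel scope deadbeef")  # placeholder scope id
        if sys.version_info >= (3, 11):
            assert task.cancelling() == 1  # one pending cancellation request
        try:
            await task
        except asyncio.CancelledError:
            pass

    asyncio.run(main())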
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index 9b8369d4..de2189ce 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -7,8 +7,18 @@
import sys
import types
import weakref
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import (
+ AsyncGenerator,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Collection,
+ Coroutine,
+ Iterable,
+ Sequence,
+)
from concurrent.futures import Future
+from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import partial
from io import IOBase
@@ -19,15 +29,8 @@
from typing import (
IO,
Any,
- AsyncGenerator,
- Awaitable,
- Callable,
- Collection,
- ContextManager,
- Coroutine,
Generic,
NoReturn,
- Sequence,
TypeVar,
cast,
overload,
@@ -1273,7 +1276,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter:
@classmethod
def open_signal_receiver(
cls, *signals: Signals
- ) -> ContextManager[AsyncIterator[Signals]]:
+ ) -> AbstractContextManager[AsyncIterator[Signals]]:
return _SignalReceiver(signals)
@classmethod
diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py
index 9503d944..23ccb0d6 100644
--- a/src/anyio/_core/_fileio.py
+++ b/src/anyio/_core/_fileio.py
@@ -3,7 +3,7 @@
import os
import pathlib
import sys
-from collections.abc import Callable, Iterable, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from functools import partial
from os import PathLike
@@ -12,7 +12,6 @@
TYPE_CHECKING,
Any,
AnyStr,
- AsyncIterator,
Final,
Generic,
overload,
diff --git a/src/anyio/_core/_signals.py b/src/anyio/_core/_signals.py
index 115c749b..f3451d30 100644
--- a/src/anyio/_core/_signals.py
+++ b/src/anyio/_core/_signals.py
@@ -1,13 +1,15 @@
from __future__ import annotations
from collections.abc import AsyncIterator
+from contextlib import AbstractContextManager
from signal import Signals
-from typing import ContextManager
from ._eventloop import get_async_backend
-def open_signal_receiver(*signals: Signals) -> ContextManager[AsyncIterator[Signals]]:
+def open_signal_receiver(
+ *signals: Signals,
+) -> AbstractContextManager[AsyncIterator[Signals]]:
"""
Start receiving operating system signals.
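``open_signal_receiver()`` is used as a synchronous context manager that yields an async
iterator of received signals, which is why ``AbstractContextManager[AsyncIterator[Signals]]``
is the appropriate return annotation. A minimal usage sketch (POSIX-oriented, and assumed
to run in the main thread's event loop)::

    import signal

    import anyio
    from anyio import open_signal_receiver

    async def main() -> None:
        with open_signal_receiver(signal.SIGTERM, signal.SIGINT) as signals:
            async for signum in signals:
                print(f"received {signal.strsignal(signum)}")
                break  # stop after the first signal for this demo

    anyio.run(main)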
diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py
index aa6b0c22..6a9814e5 100644
--- a/src/anyio/_core/_streams.py
+++ b/src/anyio/_core/_streams.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import math
-from typing import Tuple, TypeVar
+from typing import TypeVar
from warnings import warn
from ..streams.memory import (
@@ -14,7 +14,7 @@
class create_memory_object_stream(
- Tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]],
+ tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]],
):
"""
Create a memory object stream.
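Because ``create_memory_object_stream`` is declared as a class subclassing the tuple
generic, the switch from ``typing.Tuple`` to built-in ``tuple`` here is typing-only;
usage is unchanged. A short usage sketch::

    import anyio
    from anyio import create_memory_object_stream, create_task_group

    async def main() -> None:
        send_stream, receive_stream = create_memory_object_stream[str](max_buffer_size=1)

        async def producer() -> None:
            async with send_stream:
                await send_stream.send("hello")

        async with create_task_group() as tg:
            tg.start_soon(producer)
            async with receive_stream:
                assert await receive_stream.receive() == "hello"

    anyio.run(main)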
diff --git a/src/anyio/_core/_subprocesses.py b/src/anyio/_core/_subprocesses.py
index 1ac2d549..7ba41a5b 100644
--- a/src/anyio/_core/_subprocesses.py
+++ b/src/anyio/_core/_subprocesses.py
@@ -160,38 +160,25 @@ async def open_process(
child process prior to the execution of the subprocess. (POSIX only)
:param pass_fds: sequence of file descriptors to keep open between the parent and
child processes. (POSIX only)
- :param user: effective user to run the process as (Python >= 3.9; POSIX only)
- :param group: effective group to run the process as (Python >= 3.9; POSIX only)
- :param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9;
- POSIX only)
+ :param user: effective user to run the process as (POSIX only)
+ :param group: effective group to run the process as (POSIX only)
+ :param extra_groups: supplementary groups to set in the subprocess (POSIX only)
:param umask: if not negative, this umask is applied in the child process before
- running the given command (Python >= 3.9; POSIX only)
+ running the given command (POSIX only)
:return: an asynchronous process object
"""
kwargs: dict[str, Any] = {}
if user is not None:
- if sys.version_info < (3, 9):
- raise TypeError("the 'user' argument requires Python 3.9 or later")
-
kwargs["user"] = user
if group is not None:
- if sys.version_info < (3, 9):
- raise TypeError("the 'group' argument requires Python 3.9 or later")
-
kwargs["group"] = group
if extra_groups is not None:
- if sys.version_info < (3, 9):
- raise TypeError("the 'extra_groups' argument requires Python 3.9 or later")
-
kwargs["extra_groups"] = group
if umask >= 0:
- if sys.version_info < (3, 9):
- raise TypeError("the 'umask' argument requires Python 3.9 or later")
-
kwargs["umask"] = umask
return await get_async_backend().open_process(
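With Python 3.8 gone, the ``user``, ``group``, ``extra_groups`` and ``umask`` arguments
are forwarded unconditionally. A hedged sketch of how a caller might use one of them
(POSIX only; ``run_process`` shares these keyword arguments with ``open_process``, and
changing credentials may require elevated privileges)::

    import sys

    import anyio
    from anyio import run_process

    async def main() -> None:
        # The umask is applied in the child process before the command runs
        result = await run_process([sys.executable, "-c", "print('hello')"], umask=0o077)
        print(result.stdout.decode())

    anyio.run(main)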
diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py
index 2c73bb9f..93d0e9d2 100644
--- a/src/anyio/abc/_eventloop.py
+++ b/src/anyio/abc/_eventloop.py
@@ -3,7 +3,8 @@
import math
import sys
from abc import ABCMeta, abstractmethod
-from collections.abc import AsyncIterator, Awaitable
+from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
+from contextlib import AbstractContextManager
from os import PathLike
from signal import Signals
from socket import AddressFamily, SocketKind, socket
@@ -11,9 +12,6 @@
IO,
TYPE_CHECKING,
Any,
- Callable,
- ContextManager,
- Sequence,
TypeVar,
Union,
overload,
@@ -352,7 +350,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter:
@abstractmethod
def open_signal_receiver(
cls, *signals: Signals
- ) -> ContextManager[AsyncIterator[Signals]]:
+ ) -> AbstractContextManager[AsyncIterator[Signals]]:
pass
@classmethod
diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py
index b321225a..1c6a450c 100644
--- a/src/anyio/abc/_sockets.py
+++ b/src/anyio/abc/_sockets.py
@@ -8,7 +8,7 @@
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
from types import TracebackType
-from typing import Any, Tuple, TypeVar, Union
+from typing import Any, TypeVar, Union
from .._core._typedattr import (
TypedAttributeProvider,
@@ -19,10 +19,10 @@
from ._tasks import TaskGroup
IPAddressType = Union[str, IPv4Address, IPv6Address]
-IPSockAddrType = Tuple[str, int]
+IPSockAddrType = tuple[str, int]
SockAddrType = Union[IPSockAddrType, str]
-UDPPacketType = Tuple[bytes, IPSockAddrType]
-UNIXDatagramPacketType = Tuple[bytes, str]
+UDPPacketType = tuple[bytes, IPSockAddrType]
+UNIXDatagramPacketType = tuple[bytes, str]
T_Retval = TypeVar("T_Retval")
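``UDPPacketType`` is the ``(payload, (host, port))`` pair produced by unconnected UDP
sockets; the alias change above is typing-only. A small sketch of where that type shows
up in practice (binding to an ephemeral loopback port and pinging ourselves is assumed to
work on the host)::

    import anyio
    from anyio import create_udp_socket
    from anyio.abc import SocketAttribute

    async def main() -> None:
        async with await create_udp_socket(local_host="127.0.0.1") as udp:
            host, port = udp.extra(SocketAttribute.local_address)
            await udp.sendto(b"ping", host, port)
            packet = await udp.receive()  # a UDPPacketType: tuple[bytes, IPSockAddrType]
            data, (peer_host, peer_port) = packet
            assert data == b"ping"

    anyio.run(main)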
diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py
index b8785845..93a4cfe8 100644
--- a/src/anyio/from_thread.py
+++ b/src/anyio/from_thread.py
@@ -3,15 +3,17 @@
import sys
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import Future
-from contextlib import AbstractContextManager, contextmanager
+from contextlib import (
+ AbstractAsyncContextManager,
+ AbstractContextManager,
+ contextmanager,
+)
from dataclasses import dataclass, field
from inspect import isawaitable
from threading import Lock, Thread, get_ident
from types import TracebackType
from typing import (
Any,
- AsyncContextManager,
- ContextManager,
Generic,
TypeVar,
cast,
@@ -87,7 +89,9 @@ class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
type[BaseException] | None, BaseException | None, TracebackType | None
] = (None, None, None)
- def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
+ def __init__(
+ self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
+ ):
self._async_cm = async_cm
self._portal = portal
@@ -374,8 +378,8 @@ def task_done(future: Future[T_Retval]) -> None:
return f, task_status_future.result()
def wrap_async_context_manager(
- self, cm: AsyncContextManager[T_co]
- ) -> ContextManager[T_co]:
+ self, cm: AbstractAsyncContextManager[T_co]
+ ) -> AbstractContextManager[T_co]:
"""
Wrap an async context manager as a synchronous context manager via this portal.
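``wrap_async_context_manager()`` lets synchronous code drive an async context manager
through a blocking portal, hence the ``AbstractAsyncContextManager`` parameter and
``AbstractContextManager`` return type above. A minimal sketch (the ``resource``
context manager is a made-up stand-in)::

    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager

    from anyio.from_thread import start_blocking_portal

    @asynccontextmanager
    async def resource() -> AsyncIterator[str]:
        # Stand-in for something like an async database connection
        yield "connected"

    with start_blocking_portal() as portal:
        with portal.wrap_async_context_manager(resource()) as value:
            print(value)  # prints "connected"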
diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py
index 558c72ec..c9fe1bde 100644
--- a/src/anyio/pytest_plugin.py
+++ b/src/anyio/pytest_plugin.py
@@ -4,7 +4,7 @@
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from inspect import isasyncgenfunction, iscoroutinefunction
-from typing import Any, Dict, Tuple, cast
+from typing import Any, cast
import pytest
import sniffio
@@ -27,7 +27,7 @@ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
return backend, {}
elif isinstance(backend, tuple) and len(backend) == 2:
if isinstance(backend[0], str) and isinstance(backend[1], dict):
- return cast(Tuple[str, Dict[str, Any]], backend)
+ return cast(tuple[str, dict[str, Any]], backend)
raise TypeError("anyio_backend must be either a string or tuple of (string, dict)")
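``extract_backend_and_options()`` normalizes the two forms the ``anyio_backend`` fixture
may take: a bare backend name, or a ``(name, options)`` tuple. A sketch of the tuple form
in a test suite; the ``use_uvloop`` option is just an example::

    import pytest

    @pytest.fixture
    def anyio_backend() -> tuple[str, dict[str, object]]:
        # Passed through extract_backend_and_options() by the plugin
        return "asyncio", {"use_uvloop": True}

    @pytest.mark.anyio
    async def test_something(anyio_backend_name: str) -> None:
        assert anyio_backend_name == "asyncio"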
diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py
index e913eedb..83240b4d 100644
--- a/src/anyio/streams/tls.py
+++ b/src/anyio/streams/tls.py
@@ -7,7 +7,7 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import wraps
-from typing import Any, Tuple, TypeVar
+from typing import Any, TypeVar
from .. import (
BrokenResourceError,
@@ -25,8 +25,8 @@
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
-_PCTRTT = Tuple[Tuple[str, str], ...]
-_PCTRTTT = Tuple[_PCTRTT, ...]
+_PCTRTT = tuple[tuple[str, str], ...]
+_PCTRTTT = tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
diff --git a/tests/streams/test_stapled.py b/tests/streams/test_stapled.py
index d7614314..b032e215 100644
--- a/tests/streams/test_stapled.py
+++ b/tests/streams/test_stapled.py
@@ -1,8 +1,9 @@
from __future__ import annotations
from collections import deque
+from collections.abc import Iterable
from dataclasses import InitVar, dataclass, field
-from typing import Iterable, TypeVar
+from typing import TypeVar
import pytest
diff --git a/tests/streams/test_tls.py b/tests/streams/test_tls.py
index 9846e0c1..90307657 100644
--- a/tests/streams/test_tls.py
+++ b/tests/streams/test_tls.py
@@ -2,9 +2,9 @@
import socket
import ssl
-from contextlib import ExitStack
+from contextlib import AbstractContextManager, ExitStack
from threading import Thread
-from typing import ContextManager, NoReturn
+from typing import NoReturn
import pytest
from pytest_mock import MockerFixture
@@ -210,7 +210,7 @@ def serve_sync() -> None:
finally:
conn.close()
- client_cm: ContextManager = ExitStack()
+ client_cm: AbstractContextManager = ExitStack()
if client_compatible and not server_compatible:
client_cm = pytest.raises(BrokenResourceError)
diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py
index f69f513d..c37614e7 100644
--- a/tests/test_from_thread.py
+++ b/tests/test_from_thread.py
@@ -4,12 +4,12 @@
import sys
import threading
import time
-from collections.abc import Awaitable, Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable
from concurrent import futures
from concurrent.futures import CancelledError, Future
from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar
-from typing import Any, AsyncGenerator, Literal, NoReturn, TypeVar
+from typing import Any, Literal, NoReturn, TypeVar
import pytest
import sniffio
diff --git a/tests/test_signals.py b/tests/test_signals.py
index 16861b82..161633d2 100644
--- a/tests/test_signals.py
+++ b/tests/test_signals.py
@@ -3,7 +3,7 @@
import os
import signal
import sys
-from typing import AsyncIterable
+from collections.abc import AsyncIterable
import pytest
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index 832ae6bc..42937a36 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -10,12 +10,13 @@
import tempfile
import threading
import time
+from collections.abc import Generator, Iterable, Iterator
from contextlib import suppress
from pathlib import Path
from socket import AddressFamily
from ssl import SSLContext, SSLError
from threading import Thread
-from typing import Any, Generator, Iterable, Iterator, NoReturn, TypeVar, cast
+from typing import Any, NoReturn, TypeVar, cast
import psutil
import pytest
@@ -1158,9 +1159,10 @@ async def handle(stream: SocketStream) -> None:
async with stream:
await stream.send(b"Hello\n")
- async with await create_unix_listener(
- socket_path
- ) as listener, create_task_group() as tg:
+ async with (
+ await create_unix_listener(socket_path) as listener,
+ create_task_group() as tg,
+ ):
tg.start_soon(listener.serve, handle)
await wait_all_tasks_blocked()
diff --git a/tests/test_subprocesses.py b/tests/test_subprocesses.py
index b1ff553d..adf029a3 100644
--- a/tests/test_subprocesses.py
+++ b/tests/test_subprocesses.py
@@ -4,7 +4,6 @@
import platform
import sys
from collections.abc import Callable
-from contextlib import ExitStack
from pathlib import Path
from subprocess import CalledProcessError
from textwrap import dedent
@@ -135,9 +134,11 @@ async def test_run_process_connect_to_file(tmp_path: Path) -> None:
stdinfile.write_text("Hello, process!\n")
stdoutfile = tmp_path / "stdout"
stderrfile = tmp_path / "stderr"
- with stdinfile.open("rb") as fin, stdoutfile.open("wb") as fout, stderrfile.open(
- "wb"
- ) as ferr:
+ with (
+ stdinfile.open("rb") as fin,
+ stdoutfile.open("wb") as fout,
+ stderrfile.open("wb") as ferr,
+ ):
async with await open_process(
[
sys.executable,
@@ -271,30 +272,21 @@ async def test_py39_arguments(
anyio_backend_name: str,
anyio_backend_options: dict[str, Any],
) -> None:
- with ExitStack() as stack:
- if sys.version_info < (3, 9):
- stack.enter_context(
- pytest.raises(
- TypeError,
- match=rf"the {argname!r} argument requires Python 3.9 or later",
- )
- )
-
- try:
- await run_process(
- [sys.executable, "-c", "print('hello')"],
- **{argname: argvalue_factory()},
- )
- except ValueError as exc:
- if (
- "unexpected kwargs" in str(exc)
- and anyio_backend_name == "asyncio"
- and anyio_backend_options["loop_factory"]
- and anyio_backend_options["loop_factory"].__module__ == "uvloop"
- ):
- pytest.skip(f"the {argname!r} argument is not supported by uvloop yet")
+ try:
+ await run_process(
+ [sys.executable, "-c", "print('hello')"],
+ **{argname: argvalue_factory()},
+ )
+ except ValueError as exc:
+ if (
+ "unexpected kwargs" in str(exc)
+ and anyio_backend_name == "asyncio"
+ and anyio_backend_options["loop_factory"]
+ and anyio_backend_options["loop_factory"].__module__ == "uvloop"
+ ):
+ pytest.skip(f"the {argname!r} argument is not supported by uvloop yet")
- raise
+ raise
async def test_close_early() -> None:
@@ -316,9 +308,10 @@ async def test_close_while_reading() -> None:
time.sleep(3)
""")
- async with await open_process(
- [sys.executable, "-c", code]
- ) as process, create_task_group() as tg:
+ async with (
+ await open_process([sys.executable, "-c", code]) as process,
+ create_task_group() as tg,
+ ):
assert process.stdout
tg.start_soon(process.stdout.aclose)
with pytest.raises(ClosedResourceError):
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index a4603612..31490572 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -4,11 +4,13 @@
import math
import sys
import time
+from asyncio import CancelledError
from collections.abc import AsyncGenerator, Coroutine, Generator
from typing import Any, NoReturn, cast
import pytest
from exceptiongroup import catch
+from pytest_mock import MockerFixture
import anyio
from anyio import (
@@ -257,6 +259,36 @@ async def taskfunc() -> None:
await task
+@pytest.mark.parametrize("anyio_backend", ["asyncio"])
+async def test_cancel_with_nested_task_groups(mocker: MockerFixture) -> None:
+ """Regression test for #695."""
+
+ async def shield_task() -> None:
+ with CancelScope(shield=True) as scope:
+ shielded_cancel_spy = mocker.spy(scope, "_deliver_cancellation")
+ await sleep(0.5)
+
+ assert len(outer_cancel_spy.call_args_list) < 10
+ shielded_cancel_spy.assert_not_called()
+
+ async def middle_task() -> None:
+ try:
+ async with create_task_group() as tg:
+ middle_cancel_spy = mocker.spy(tg.cancel_scope, "_deliver_cancellation")
+ tg.start_soon(shield_task, name="shield task")
+ finally:
+ assert len(middle_cancel_spy.call_args_list) < 10
+ assert len(outer_cancel_spy.call_args_list) < 10
+
+ async with create_task_group() as tg:
+ outer_cancel_spy = mocker.spy(tg.cancel_scope, "_deliver_cancellation")
+ tg.start_soon(middle_task, name="middle task")
+ await wait_all_tasks_blocked()
+ tg.cancel_scope.cancel()
+
+ assert len(outer_cancel_spy.call_args_list) < 10
+
+
async def test_start_exception_delivery(anyio_backend_name: str) -> None:
def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
task_status.started("hello")
@@ -394,7 +426,7 @@ async def g() -> NoReturn:
async with create_task_group():
await sleep(1)
- assert False
+ pytest.fail("Execution should not reach this point")
async with create_task_group() as tg:
tg.start_soon(g)
@@ -455,19 +487,6 @@ async def test_cancel_before_entering_scope() -> None:
pytest.fail("execution should not reach this point")
-@pytest.mark.xfail(
- sys.version_info < (3, 11), reason="Requires asyncio.Task.cancelling()"
-)
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancel_counter_nested_scopes() -> None:
- with CancelScope() as root_scope:
- with CancelScope():
- root_scope.cancel()
- await sleep(0.5)
-
- assert not cast(asyncio.Task, asyncio.current_task()).cancelling()
-
-
async def test_exception_group_children() -> None:
with pytest.raises(BaseExceptionGroup) as exc:
async with create_task_group() as tg:
@@ -660,17 +679,92 @@ async def test_cancelled_not_caught() -> None:
assert not scope.cancelled_caught
+async def test_cancelled_scope_based_checkpoint() -> None:
+ """Regression test closely related to #698."""
+ with CancelScope() as outer_scope:
+ outer_scope.cancel()
+
+ # The following three lines are a way to implement a checkpoint function.
+ # See also https://github.com/python-trio/trio/issues/860.
+ with CancelScope() as inner_scope:
+ inner_scope.cancel()
+ await sleep_forever()
+
+ pytest.fail("checkpoint should have raised")
+
+ assert not inner_scope.cancelled_caught
+ assert outer_scope.cancelled_caught
+
+
+async def test_cancelled_raises_beyond_origin_unshielded() -> None:
+ with CancelScope() as outer_scope:
+ with CancelScope() as inner_scope:
+ inner_scope.cancel()
+ try:
+ await checkpoint()
+ finally:
+ outer_scope.cancel()
+
+ pytest.fail("checkpoint should have raised")
+
+ pytest.fail("exiting the inner scope should've raised a cancellation error")
+
+ # Here, the outer scope is responsible for the cancellation, so the inner scope
+ # won't catch the cancellation exception, but the outer scope will
+ assert not inner_scope.cancelled_caught
+ assert outer_scope.cancelled_caught
+
+
+async def test_cancelled_raises_beyond_origin_shielded() -> None:
+ code_between_scopes_was_run = False
+ with CancelScope() as outer_scope:
+ with CancelScope(shield=True) as inner_scope:
+ inner_scope.cancel()
+ try:
+ await checkpoint()
+ finally:
+ outer_scope.cancel()
+
+ pytest.fail("checkpoint should have raised")
+
+ code_between_scopes_was_run = True
+
+ # Here, the inner scope is the one responsible for the cancellation, so even though
+ # the outer scope was also cancelled, it is not considered to have "caught" the
+ # cancellation: the shielded inner scope triggered it and swallowed it
+ assert code_between_scopes_was_run
+ assert inner_scope.cancelled_caught
+ assert not outer_scope.cancelled_caught
+
+
+async def test_empty_taskgroup_contains_yield_point() -> None:
+ """
+ Test that a task group yields at exit at least once, even with no child tasks to
+ wait on.
+
+ """
+ outer_task_ran = False
+
+ async def outer_task() -> None:
+ nonlocal outer_task_ran
+ outer_task_ran = True
+
+ async with create_task_group() as tg_outer:
+ for _ in range(2): # this is to make sure Trio actually schedules outer_task()
+ async with create_task_group():
+ tg_outer.start_soon(outer_task)
+
+ assert outer_task_ran
+
+
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_cancel_host_asyncgen() -> None:
done = False
async def host_task() -> None:
nonlocal done
- async with create_task_group() as tg:
- with CancelScope(shield=True) as inner_scope:
- assert inner_scope.shield
- tg.cancel_scope.cancel()
- await checkpoint()
+ with CancelScope() as inner_scope:
+ inner_scope.cancel()
with pytest.raises(get_cancelled_exc_class()):
await checkpoint()
@@ -780,12 +874,12 @@ async def child(fail: bool) -> None:
async def test_cancel_cascade() -> None:
async def do_something() -> NoReturn:
async with create_task_group() as tg2:
- tg2.start_soon(sleep, 1)
+ tg2.start_soon(sleep, 1, name="sleep")
- raise Exception("foo")
+ pytest.fail("Execution should not reach this point")
async with create_task_group() as tg:
- tg.start_soon(do_something)
+ tg.start_soon(do_something, name="do_something")
await wait_all_tasks_blocked()
tg.cancel_scope.cancel()
@@ -970,7 +1064,7 @@ async def g() -> NoReturn:
tg2.start_soon(anyio.sleep, 10)
await anyio.sleep(1)
- assert False
+ pytest.fail("Execution should not have reached this line")
async with anyio.create_task_group() as tg:
tg.start_soon(g)
@@ -1321,6 +1415,77 @@ async def test_cancel_message_replaced(self) -> None:
except asyncio.CancelledError:
pytest.fail("Should have swallowed the CancelledError")
+ async def test_cancel_counter_nested_scopes(self) -> None:
+ with CancelScope() as root_scope:
+ with CancelScope():
+ root_scope.cancel()
+ await checkpoint()
+
+ assert not cast(asyncio.Task, asyncio.current_task()).cancelling()
+
+ async def test_uncancel_after_taskgroup_cancelled(self) -> None:
+ """
+ Test that a cancel scope only uncancels the host task as many times as it has
+ cancelled that specific task, and won't count child task cancellations towards
+ that amount.
+ """
+
+ async def child_task(task_status: TaskStatus[None]) -> None:
+ async with create_task_group() as tg:
+ tg.start_soon(sleep, 3)
+ await wait_all_tasks_blocked()
+ task_status.started()
+
+ task = asyncio.current_task()
+ assert task
+ with pytest.raises(CancelledError):
+ async with create_task_group() as tg:
+ await tg.start(child_task)
+ task.cancel()
+
+ assert task.cancelling() == 1
+
+ async def test_uncancel_after_group_aexit_native_cancel(self) -> None:
+ """Closely related to #695."""
+ done = anyio.Event()
+
+ async def shield_task() -> None:
+ with CancelScope(shield=True):
+ await done.wait()
+
+ async def middle_task() -> None:
+ async with create_task_group() as tg:
+ tg.start_soon(shield_task)
+
+ task = asyncio.get_running_loop().create_task(middle_task())
+ try:
+ await wait_all_tasks_blocked()
+ task.cancel("native 1")
+ await sleep(0.1)
+ task.cancel("native 2")
+ finally:
+ done.set()
+
+ with pytest.raises(asyncio.CancelledError) as exc:
+ await task
+
+ # Neither native cancellation should have been uncancelled, and the latest
+ # cancellation message should be the one coming out of the task group.
+ assert task.cancelling() == 2
+ assert str(exc.value) == "native 2"
+
+ async def test_uncancel_after_child_task_failed(self) -> None:
+ async def taskfunc() -> None:
+ raise Exception("dummy error")
+
+ with pytest.raises(ExceptionGroup) as exc_info:
+ async with create_task_group() as tg:
+ tg.start_soon(taskfunc)
+
+ assert len(exc_info.value.exceptions) == 1
+ assert str(exc_info.value.exceptions[0]) == "dummy error"
+ assert not cast(asyncio.Task, asyncio.current_task()).cancelling()
+
async def test_cancel_before_entering_task_group() -> None:
with CancelScope() as scope:
diff --git a/tests/test_typedattr.py b/tests/test_typedattr.py
index 9930996a..48e175d5 100644
--- a/tests/test_typedattr.py
+++ b/tests/test_typedattr.py
@@ -1,6 +1,7 @@
from __future__ import annotations
-from typing import Any, Callable, Mapping
+from collections.abc import Mapping
+from typing import Any, Callable
import pytest