diff --git a/uvloop/loop.pyi b/uvloop/loop.pyi index 9c8c4623..0ab8b7cb 100644 --- a/uvloop/loop.pyi +++ b/uvloop/loop.pyi @@ -11,6 +11,7 @@ from typing import ( Generator, List, Optional, + Protocol, Sequence, Tuple, TypeVar, @@ -24,6 +25,14 @@ _ExceptionHandler = Callable[[asyncio.AbstractEventLoop, _Context], Any] _SSLContext = Union[bool, None, ssl.SSLContext] _ProtocolT = TypeVar("_ProtocolT", bound=asyncio.BaseProtocol) + +class TaskFactoryCallable(Protocol): + def __call__( + self, loop: asyncio.AbstractEventLoop, coro: Generator[Any, None, _T], **kwargs: Any + ) -> asyncio.Future[_T]: + ... + + class Loop: def call_soon( self, callback: Callable[..., Any], *args: Any, context: Optional[Any] = ... @@ -52,17 +61,8 @@ class Loop: *, name: Optional[str] = ..., ) -> asyncio.Task[_T]: ... - def set_task_factory( - self, - factory: Optional[ - Callable[[asyncio.AbstractEventLoop, Generator[Any, None, _T]], asyncio.Future[_T]] - ], - ) -> None: ... - def get_task_factory( - self, - ) -> Optional[ - Callable[[asyncio.AbstractEventLoop, Generator[Any, None, _T]], asyncio.Future[_T]] - ]: ... + def set_task_factory(self, factory: Optional[TaskFactoryCallable]) -> None: ... + def get_task_factory(self) -> Optional[TaskFactoryCallable]: ... @overload def run_until_complete(self, future: Generator[Any, None, _T]) -> _T: ... @overload