From 6eb780f58cbf56f01a19a16fd57a3ef87dc60fdf Mon Sep 17 00:00:00 2001 From: Vladimir Magamedov Date: Sun, 19 May 2024 21:35:48 +0300 Subject: [PATCH] Refactored server close logic to gracefully exit without using GOAWAY frames --- grpclib/client.py | 9 ++++--- grpclib/protocol.py | 5 ++++ grpclib/server.py | 56 +++++++++++++++++++++++++------------------- tests/stubs.py | 3 +++ tests/test_memory.py | 3 --- 5 files changed, 46 insertions(+), 30 deletions(-) diff --git a/grpclib/client.py b/grpclib/client.py index 2f788e6..0fe146c 100644 --- a/grpclib/client.py +++ b/grpclib/client.py @@ -62,7 +62,10 @@ class Handler(AbstractHandler): - connection_lost = False + closing = False + + def connection_made(self, connection: Any) -> None: + pass def accept(self, stream: Any, headers: Any, release_stream: Any) -> None: raise NotImplementedError('Client connection can not accept requests') @@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None: pass def close(self) -> None: - self.connection_lost = True + self.closing = True class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]): @@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol: @property def _connected(self) -> bool: return (self._protocol is not None - and not self._protocol.handler.connection_lost) + and not cast(Handler, self._protocol.handler).closing) async def __connect__(self) -> H2Protocol: if not self._connected: diff --git a/grpclib/protocol.py b/grpclib/protocol.py index 66f66c1..011eee6 100644 --- a/grpclib/protocol.py +++ b/grpclib/protocol.py @@ -488,6 +488,10 @@ def closable(self) -> bool: class AbstractHandler(ABC): + @abstractmethod + def connection_made(self, connection: Connection) -> None: + pass + @abstractmethod def accept( self, @@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None: self.connection.flush() self.connection.initialize() + self.handler.connection_made(self.connection) self.processor = EventsProcessor(self.handler, self.connection) def data_received(self, data: bytes) -> None: diff --git a/grpclib/server.py b/grpclib/server.py index b652a75..cea3fc2 100644 --- a/grpclib/server.py +++ b/grpclib/server.py @@ -4,6 +4,7 @@ import logging import asyncio import warnings +from functools import partial from types import TracebackType from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast @@ -12,6 +13,7 @@ import h2.config import h2.exceptions +from h2.errors import ErrorCodes from multidict import MultiDict @@ -24,7 +26,7 @@ from .metadata import Deadline, encode_grpc_message, _Metadata from .metadata import encode_metadata, decode_metadata, _MetadataLike from .metadata import _STATUS_DETAILS_KEY, encode_bin_value -from .protocol import H2Protocol, AbstractHandler +from .protocol import H2Protocol, AbstractHandler, Connection from .exceptions import GRPCError, ProtocolError, StreamTerminatedError from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec @@ -493,9 +495,8 @@ def __gc_step__(self) -> None: self.__gc_collect__() -class Handler(_GC, AbstractHandler): - __gc_interval__ = 10 - +class Handler(AbstractHandler): + connection: Connection closing = False def __init__( @@ -511,13 +512,18 @@ def __init__( self.dispatch = dispatch self.loop = asyncio.get_event_loop() self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {} - self._cancelled: Set['asyncio.Task[None]'] = set() - def __gc_collect__(self) -> None: - self._tasks = {s: t for s, t in self._tasks.items() - if not t.done()} - self._cancelled = {t for t in self._cancelled - if not t.done()} + def connection_made(self, connection: Connection) -> None: + self.connection = connection + + def handler_done( + self, + stream: 'protocol.Stream', + _: asyncio.Future[None], + ) -> None: + self._tasks.pop(stream) + if self.closing and not self._tasks: + self.connection.close() def accept( self, @@ -525,30 +531,32 @@ def accept( headers: _Headers, release_stream: Callable[[], Any], ) -> None: - self.__gc_step__() - self._tasks[stream] = self.loop.create_task(request_handler( - self.mapping, stream, headers, self.codec, - self.status_details_codec, self.dispatch, release_stream, - )) + if self.closing: + stream.reset_nowait(ErrorCodes.REFUSED_STREAM) + release_stream() + else: + task = self._tasks[stream] = self.loop.create_task(request_handler( + self.mapping, stream, headers, self.codec, + self.status_details_codec, self.dispatch, release_stream, + )) + task.add_done_callback(partial(self.handler_done, stream)) def cancel(self, stream: 'protocol.Stream') -> None: - task = self._tasks.pop(stream) - task.cancel() - self._cancelled.add(task) + self._tasks[stream].cancel() def close(self) -> None: for task in self._tasks.values(): task.cancel() - self._cancelled.update(self._tasks.values()) self.closing = True async def wait_closed(self) -> None: - if self._cancelled: - await asyncio.wait(self._cancelled) + if self._tasks: + await asyncio.wait(self._tasks.values()) + else: + self.connection.close() def check_closed(self) -> bool: - self.__gc_collect__() - return not self._tasks and not self._cancelled + return not self._tasks class Server(_GC): @@ -737,11 +745,11 @@ async def wait_closed(self) -> None: if self._server is None or self._server_closed_fut is None: raise RuntimeError('Server is not started') await self._server_closed_fut - await self._server.wait_closed() if self._handlers: await asyncio.wait({ self._loop.create_task(h.wait_closed()) for h in self._handlers }) + await self._server.wait_closed() async def __aenter__(self) -> 'Server': return self diff --git a/tests/stubs.py b/tests/stubs.py index e2eb753..c4ac09d 100644 --- a/tests/stubs.py +++ b/tests/stubs.py @@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler): headers = None release_stream = None + def connection_made(self, connection): + pass + def accept(self, stream, headers, release_stream): self.stream = stream self.headers = headers diff --git a/tests/test_memory.py b/tests/test_memory.py index 9b37f3f..e40490f 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -83,14 +83,11 @@ async def test_stream(): cs = ClientServer(DummyService, DummyServiceStub) async with cs as (_, stub): await stub.UnaryUnary(DummyRequest(value='ping')) - handler = next(iter(cs.server._handlers)) - handler.__gc_collect__() gc.collect() gc.disable() try: pre = set(collect()) await stub.UnaryUnary(DummyRequest(value='ping')) - handler.__gc_collect__() post = collect() diff = set(post).difference(pre)