diff --git a/nats/aio/client.py b/nats/aio/client.py index 46b19638..3d4bd51a 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -53,7 +53,7 @@ DEFAULT_SUB_PENDING_MSGS_LIMIT, Subscription, ) -from .transport import TcpTransport, Transport, WebSocketTransport +from .transport import Transport, TcpTransport, WebSocketTransport __version__ = '2.2.0' __lang__ = 'python3' @@ -1271,32 +1271,24 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if s.last_attempt is not None and now < s.last_attempt + self.options[ - "reconnect_time_wait"]: + delay = self.options["reconnect_time_wait"] + if s.last_attempt is not None and now < s.last_attempt + delay: # Backoff connecting to server if we attempted recently. - await asyncio.sleep(self.options["reconnect_time_wait"]) + await asyncio.sleep(delay) try: s.last_attempt = time.monotonic() - if not self._transport: - if s.uri.scheme in ("ws", "wss"): - self._transport = WebSocketTransport() - else: - # use TcpTransport as a fallback - self._transport = TcpTransport() - if s.uri.scheme == "wss": - # wss is expected to connect directly with tls - await self._transport.connect_tls( - s.uri, - ssl_context=self.ssl_context, - buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] - ) + transport_class: type[Transport] + if s.uri.scheme in ("ws", "wss"): + transport_class = WebSocketTransport else: - await self._transport.connect( - s.uri, - buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] - ) + transport_class = TcpTransport + ssl_context = self.ssl_context if s.uri.scheme == "wss" else None + self._transport = await transport_class.connect( + s.uri, + DEFAULT_BUFFER_SIZE, + self.options['connect_timeout'], + ssl_context, + ) self._current_server = s break except Exception as e: @@ -1885,7 +1877,6 @@ async def _process_connect_init(self) -> None: await self._transport.connect_tls( hostname, self.ssl_context, - DEFAULT_BUFFER_SIZE, self.options['connect_timeout'], ) diff --git a/nats/aio/transport.py b/nats/aio/transport.py index e5952eaf..d1d7c464 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -13,10 +13,12 @@ class Transport(abc.ABC): + @classmethod @abc.abstractmethod async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> Transport: """ Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be obtained calling urllib.parse.urlparse. @@ -28,7 +30,6 @@ async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ): """ @@ -39,14 +40,14 @@ async def connect_tls( pass @abc.abstractmethod - def write(self, payload: bytes): + def write(self, payload: bytes) -> None: """ Write bytes to underlying transport. Needs a call to drain() to be successfully written. """ pass @abc.abstractmethod - def writelines(self, payload: list[bytes]): + def writelines(self, payload: list[bytes]) -> None: """ Writes a list of bytes, one by one, to the underlying transport. Needs a call to drain() to be successfully written. @@ -69,21 +70,21 @@ async def readline(self) -> bytes: pass @abc.abstractmethod - async def drain(self): + async def drain(self) -> None: """ Flushes the bytes queued for transmission when calling write() and writelines(). """ pass @abc.abstractmethod - async def wait_closed(self): + async def wait_closed(self) -> None: """ Waits until the connection is successfully closed. """ pass @abc.abstractmethod - def close(self): + def close(self) -> None: """ Closes the underlying transport. """ @@ -97,7 +98,7 @@ def at_eof(self) -> bool: pass @abc.abstractmethod - def __bool__(self): + def __bool__(self) -> bool: """ Returns if the transport was initialized, either by calling connect of connect_tls. """ @@ -106,15 +107,27 @@ def __bool__(self): class TcpTransport(Transport): - def __init__(self): - self._bare_io_reader: asyncio.StreamReader | None = None - self._io_reader: asyncio.StreamReader | None = None - self._bare_io_writer: asyncio.StreamWriter | None = None - self._io_writer: asyncio.StreamWriter | None = None + def __init__( + self, r: asyncio.StreamReader, w: asyncio.StreamWriter + ) -> None: + self._io_reader: asyncio.StreamReader = r + self._io_writer: asyncio.StreamWriter = w + # We keep a reference to the initial transport we used when + # establishing the connection in case we later upgrade to TLS + # after getting the first INFO message. This is in order to + # prevent the GC closing the socket after we send CONNECT + # and replace the transport. + # + # See https://github.com/nats-io/asyncio-nats/issues/43 + self._bare_io_reader: asyncio.StreamReader = r + self._bare_io_writer: asyncio.StreamWriter = w + + @classmethod async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> TcpTransport: r, w = await asyncio.wait_for( asyncio.open_connection( host=uri.hostname, @@ -122,25 +135,21 @@ async def connect( limit=buffer_size, ), connect_timeout ) - # We keep a reference to the initial transport we used when - # establishing the connection in case we later upgrade to TLS - # after getting the first INFO message. This is in order to - # prevent the GC closing the socket after we send CONNECT - # and replace the transport. - # - # See https://github.com/nats-io/asyncio-nats/issues/43 - self._bare_io_reader = self._io_reader = r - self._bare_io_writer = self._io_writer = w + transport = cls(r, w) + if ssl_context is not None: + await transport.connect_tls( + uri=uri, + ssl_context=ssl_context, + connect_timeout=connect_timeout, + ) + return transport async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ) -> None: - assert self._io_writer, f'{type(self).__name__}.connect must be called first' - # manually recreate the stream reader/writer with a tls upgraded transport reader = asyncio.StreamReader() protocol = asyncio.StreamReaderProtocol(reader) @@ -157,60 +166,65 @@ async def connect_tls( ) self._io_reader, self._io_writer = reader, writer - def write(self, payload): - return self._io_writer.write(payload) + def write(self, payload: bytes) -> None: + self._io_writer.write(payload) - def writelines(self, payload): - return self._io_writer.writelines(payload) + def writelines(self, payload: list[bytes]) -> None: + self._io_writer.writelines(payload) - async def read(self, buffer_size: int): - assert self._io_reader, f'{type(self).__name__}.connect must be called first' + async def read(self, buffer_size: int) -> bytes: return await self._io_reader.read(buffer_size) - async def readline(self): + async def readline(self) -> bytes: return await self._io_reader.readline() - async def drain(self): - return await self._io_writer.drain() + async def drain(self) -> None: + await self._io_writer.drain() - async def wait_closed(self): + async def wait_closed(self) -> None: return await self._io_writer.wait_closed() - def close(self): + def close(self) -> None: return self._io_writer.close() - def at_eof(self): + def at_eof(self) -> bool: return self._io_reader.at_eof() - def __bool__(self): + def __bool__(self) -> bool: return bool(self._io_writer) and bool(self._io_reader) class WebSocketTransport(Transport): - def __init__(self): + def __init__( + self, ws: aiohttp.ClientWebSocketResponse, + client: aiohttp.ClientSession + ): + self._ws = ws + self._client = client + self._pending: asyncio.Queue[bytes] = asyncio.Queue() + self._close_task: asyncio.Future[bool] = asyncio.Future() + + @classmethod + async def connect( + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> WebSocketTransport: if not aiohttp: raise ImportError( "Could not import aiohttp transport, please install it with `pip install aiohttp`" ) - self._ws: aiohttp.ClientWebSocketResponse | None = None - self._client: aiohttp.ClientSession = aiohttp.ClientSession() - self._pending = asyncio.Queue() - self._close_task = asyncio.Future() - - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + client = aiohttp.ClientSession() # for websocket library, the uri must contain the scheme already - self._ws = await self._client.ws_connect( - uri.geturl(), timeout=connect_timeout + ws = await client.ws_connect( + uri.geturl(), timeout=connect_timeout, ssl=ssl_context ) + return cls(ws, client) async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ): self._ws = await self._client.ws_connect( @@ -219,39 +233,38 @@ async def connect_tls( timeout=connect_timeout ) - def write(self, payload): + def write(self, payload: bytes) -> None: self._pending.put_nowait(payload) - def writelines(self, payload): + def writelines(self, payload: list[bytes]) -> None: for message in payload: self.write(message) - async def read(self, buffer_size: int): + async def read(self, buffer_size: int) -> bytes: return await self.readline() - async def readline(self): + async def readline(self) -> bytes: data = await self._ws.receive() if data.type == aiohttp.WSMsgType.CLOSE: # if the connection terminated abruptly, return empty binary data to raise unexpected EOF return b'' return data.data - async def drain(self): + async def drain(self) -> None: # send all the messages pending while not self._pending.empty(): message = self._pending.get_nowait() await self._ws.send_bytes(message) - async def wait_closed(self): + async def wait_closed(self) -> None: await self._close_task await self._client.close() - self._ws = self._client = None - def close(self): + def close(self) -> None: self._close_task = asyncio.create_task(self._ws.close()) - def at_eof(self): + def at_eof(self) -> bool: return self._ws.closed - def __bool__(self): + def __bool__(self) -> bool: return bool(self._client)