diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 7aca834f..2b9f57a1 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -13,6 +13,7 @@ import logging import os import re +import socket import textwrap import time import traceback @@ -525,3 +526,41 @@ def connect_standby(cls, **kwargs): kwargs ) return pg_connection.connect(**conn_spec, loop=cls.loop) + + +class InstrumentedServer: + """ + A socket server for testing. + It will write each item from `data`, and wait for the corresponding event + in `received_events` to notify that it was received before writing the next + item from `data`. + """ + def __init__(self, data, received_events): + assert len(data) == len(received_events) + self._data = data + self._server = None + self._received_events = received_events + + async def _handle_client(self, _reader, writer): + for datum, received_event in zip(self._data, self._received_events): + writer.write(datum) + await writer.drain() + await received_event.wait() + + writer.close() + await writer.wait_closed() + + async def start(self): + """Start the server.""" + self._server = await asyncio.start_server(self._handle_client, 'localhost', 0) + assert self._server.sockets + sock = self._server.sockets[0] + addr, port = sock.getsockname() + return { + 'host': addr, + 'port': port, + } + + def stop(self): + """Stop the server.""" + self._server.close() diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..5bdf6f97 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): self.ssl_is_advisory = ssl_is_advisory def data_received(self, data): + if self.on_data.done(): + # Only expect to receive one byte here; ignore unsolicited further + # data. + return + if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and diff --git a/tests/test_connect.py b/tests/test_connect.py index 5333e2c5..b6ebca63 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -7,6 +7,7 @@ import asyncio import contextlib +import copy import gc import ipaddress import os @@ -17,11 +18,13 @@ import stat import tempfile import textwrap +import time import unittest import unittest.mock import urllib.parse import warnings import weakref +from unittest import mock import asyncpg from asyncpg import _testbase as tb @@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): await con.close() +class TestMisbehavingServer(tb.TestCase): + """Tests for client connection behaviour given a misbehaving server.""" + + async def test_tls_upgrade_extra_data_received(self): + data = [ + # First, the server writes b"S" to signal it is willing to perform + # SSL + b"S", + # Then, the server writes an unsolicted arbitrary byte afterwards + b"N", + ] + data_received_events = [asyncio.Event() for _ in data] + + # Patch out the loop's create_connection so we can instrument the proto + # we return. + old_create_conn = self.loop.create_connection + + async def _mock_create_conn(*args, **kwargs): + transport, proto = await old_create_conn(*args, **kwargs) + old_data_received = proto.data_received + + num_received = 0 + + def _data_received(*args, **kwargs): + nonlocal num_received + # Call the original data_received method + ret = old_data_received(*args, **kwargs) + # Fire the event to signal we've received this datum now. + data_received_events[num_received].set() + num_received += 1 + return ret + + proto.data_received = _data_received + + # To deterministically provoke the race we're interested in for + # this regression test, wait for all data to be received before + # returning from create_connection(). + await data_received_events[-1].wait() + return transport, proto + + server = tb.InstrumentedServer(data, data_received_events) + conn_spec = await server.start() + + # The call to connect() should raise a ConnectionResetError as the + # server will close the connection after writing all the data. + with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn), + self.assertRaises(ConnectionResetError)): + await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop) + + server.stop() + + def _get_connected_host(con): peername = con._transport.get_extra_info('peername') if isinstance(peername, tuple):