Skip to content

Commit

Permalink
TLSUpgradeProto: don't set multiple results for an event
Browse files Browse the repository at this point in the history
In the case of a misbehaving server, the client may receive more than
one byte in separate data_received() invocations from the server. While
we can't do much sane with this, we should handle it gracefully and not
crash with asyncio.InvalidStateError when trying to set another result
on the event.

Fixes #729
  • Loading branch information
w-miller committed Feb 1, 2024
1 parent c2c8d20 commit 50bb192
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
40 changes: 40 additions & 0 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os
import re
import socket
import textwrap
import time
import traceback
Expand Down Expand Up @@ -525,3 +526,42 @@ def connect_standby(cls, **kwargs):
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)


class InstrumentedServer:
"""
A socket server for testing.
It will write each item from `data`, and wait for the corresponding event
in `received_events` to notify that it was received before writing the next
item from `data`.
"""
def __init__(self, data, received_events):
assert len(data) == len(received_events)
self._data = data
self._server = None
self._received_events = received_events

async def _handle_client(self, _reader, writer):
for datum, received_event in zip(self._data, self._received_events):
writer.write(datum)
await writer.drain()
await received_event.wait()

writer.close()
await writer.wait_closed()

async def start(self):
"""Start the server."""
self._server = await asyncio.start_server(self._handle_client, 'localhost', 0)
assert self._server.sockets
sock = self._server.sockets[0]
# Account for IPv4 and IPv6
addr, port = sock.getsockname()[:2]
return {
'host': addr,
'port': port,
}

def stop(self):
"""Stop the server."""
self._server.close()
5 changes: 5 additions & 0 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
self.ssl_is_advisory = ssl_is_advisory

def data_received(self, data):
if self.on_data.done():
# Only expect to receive one byte here; ignore unsolicited further
# data.
return

if data == b'S':
self.on_data.set_result(True)
elif (self.ssl_is_advisory and
Expand Down
55 changes: 55 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import contextlib
import copy
import gc
import ipaddress
import os
Expand All @@ -17,11 +18,13 @@
import stat
import tempfile
import textwrap
import time
import unittest
import unittest.mock
import urllib.parse
import warnings
import weakref
from unittest import mock

import asyncpg
from asyncpg import _testbase as tb
Expand Down Expand Up @@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
await con.close()


class TestMisbehavingServer(tb.TestCase):
"""Tests for client connection behaviour given a misbehaving server."""

async def test_tls_upgrade_extra_data_received(self):
data = [
# First, the server writes b"S" to signal it is willing to perform
# SSL
b"S",
# Then, the server writes an unsolicted arbitrary byte afterwards
b"N",
]
data_received_events = [asyncio.Event() for _ in data]

# Patch out the loop's create_connection so we can instrument the proto
# we return.
old_create_conn = self.loop.create_connection

async def _mock_create_conn(*args, **kwargs):
transport, proto = await old_create_conn(*args, **kwargs)
old_data_received = proto.data_received

num_received = 0

def _data_received(*args, **kwargs):
nonlocal num_received
# Call the original data_received method
ret = old_data_received(*args, **kwargs)
# Fire the event to signal we've received this datum now.
data_received_events[num_received].set()
num_received += 1
return ret

proto.data_received = _data_received

# To deterministically provoke the race we're interested in for
# this regression test, wait for all data to be received before
# returning from create_connection().
await data_received_events[-1].wait()
return transport, proto

server = tb.InstrumentedServer(data, data_received_events)
conn_spec = await server.start()

# The call to connect() should raise a ConnectionResetError as the
# server will close the connection after writing all the data.
with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn),
self.assertRaises(ConnectionResetError)):
await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop)

server.stop()


def _get_connected_host(con):
peername = con._transport.get_extra_info('peername')
if isinstance(peername, tuple):
Expand Down

0 comments on commit 50bb192

Please sign in to comment.