Skip to content

Commit

Permalink
feat[Implement 'read_timeout' parameter aio-libs#974]: support read_t…
Browse files Browse the repository at this point in the history
…imeout in connect
  • Loading branch information
chenxl committed Sep 9, 2024
1 parent 83aa96e commit 15c2633
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
28 changes: 23 additions & 5 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def connect(host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None,
read_timeout=None):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand All @@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
read_default_group=read_default_group,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
auth_plugin=auth_plugin, program_name=program_name,
read_timeout=read_timeout)
return _ConnectionContextManager(coro)


Expand Down Expand Up @@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, read_timeout=None):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -184,6 +186,8 @@ def __init__(self, host="localhost", user=None, password="",
handshaking with MySQL. (omitted by default)
:param server_public_key: SHA256 authentication plugin public
key value.
:param read_timeout: The timeout for reading from the connection in seconds
(default: None - no timeout)
:param loop: asyncio loop
"""
self._loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -257,6 +261,7 @@ def __init__(self, host="localhost", user=None, password="",

self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout

self._result = None
self._affected_rows = 0
Expand Down Expand Up @@ -654,12 +659,25 @@ async def _read_packet(self, packet_type=MysqlPacket):

async def _read_bytes(self, num_bytes):
try:
data = await self._reader.readexactly(num_bytes)
if self.read_timeout:
try:
data = await asyncio.wait_for(
self._reader.readexactly(num_bytes),
self.read_timeout
)
except asyncio.TimeoutError as e:
raise asyncio.TimeoutError("Read timeout exceeded") from e
else:
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except OSError as e:
except (OSError, asyncio.TimeoutError) as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except Exception as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
Expand Down
7 changes: 7 additions & 0 deletions tests/sa/test_sa_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ async def connect(**kwargs):
return connect


@pytest.mark.run_loop
async def test_read_timeout(sa_connect):
conn = await sa_connect(read_timeout=0.01)
with pytest.raises(aiomysql.OperationalError):
await conn.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_execute_text_select(sa_connect):
conn = await sa_connect()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator):
await connection_creator(connect_timeout=0.000000000001)


@pytest.mark.run_loop
async def test_read_timeout(connection_creator):
with pytest.raises(aiomysql.OperationalError):
con = await connection_creator(read_timeout=0.01)
cur = await con.cursor()
await cur.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
tests_root = os.path.abspath(os.path.dirname(__file__))
Expand Down

0 comments on commit 15c2633

Please sign in to comment.