From c9515b32c6b147ae7263efb05f3a8fa07f10e88a Mon Sep 17 00:00:00 2001 From: Aliaksandr Akulchyk Date: Fri, 5 Jan 2024 17:30:28 +0100 Subject: [PATCH] Implement 'read_timeout' parameter --- aiomysql/connection.py | 24 +++++++++++++++++++++--- tests/sa/test_sa_connection.py | 13 ++++++++++--- tests/test_connection.py | 8 ++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 3520dfcc..42d622c1 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="", read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, + read_timeout=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): @@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="", init_command=init_command, connect_timeout=connect_timeout, read_default_group=read_default_group, + read_timeout=read_timeout, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, auth_plugin=auth_plugin, program_name=program_name) @@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="", charset='', sql_mode=None, read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, - connect_timeout=None, read_default_group=None, + connect_timeout=None, read_default_group=None, read_timeout=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): @@ -171,6 +173,8 @@ def __init__(self, host="localhost", user=None, password="", when connecting. :param read_default_group: Group to read from in the configuration file. + :param read_timeout: The timeout for reading from the connection in seconds + (default: None - no timeout) :param autocommit: Autocommit mode. None means use server default. (default: False) :param local_infile: boolean to enable the use of LOAD DATA LOCAL @@ -257,6 +261,7 @@ def __init__(self, host="localhost", user=None, password="", self.cursorclass = cursorclass self.connect_timeout = connect_timeout + self.read_timeout = read_timeout self._result = None self._affected_rows = 0 @@ -654,12 +659,25 @@ async def _read_packet(self, packet_type=MysqlPacket): async def _read_bytes(self, num_bytes): try: - data = await self._reader.readexactly(num_bytes) + if self.read_timeout: + try: + data = await asyncio.wait_for( + self._reader.readexactly(num_bytes), + self.read_timeout + ) + except asyncio.TimeoutError as e: + raise asyncio.TimeoutError("Read timeout exceeded") from e + else: + data = await self._reader.readexactly(num_bytes) except asyncio.IncompleteReadError as e: msg = "Lost connection to MySQL server during query" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e - except OSError as e: + except (OSError, asyncio.TimeoutError) as e: + msg = f"Lost connection to MySQL server during query ({e})" + self.close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e + except Exception as e: msg = f"Lost connection to MySQL server during query ({e})" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e diff --git a/tests/sa/test_sa_connection.py b/tests/sa/test_sa_connection.py index a68e9032..7897b212 100644 --- a/tests/sa/test_sa_connection.py +++ b/tests/sa/test_sa_connection.py @@ -1,13 +1,13 @@ from unittest import mock +import aiomysql +from aiomysql import sa, Cursor + import pytest from sqlalchemy import MetaData, Table, Column, Integer, String, func, select from sqlalchemy.schema import DropTable, CreateTable from sqlalchemy.sql.expression import bindparam -import aiomysql -from aiomysql import sa, Cursor - meta = MetaData() tbl = Table('sa_tbl', meta, Column('id', Integer, nullable=False, @@ -35,6 +35,13 @@ async def connect(**kwargs): return connect +@pytest.mark.run_loop +async def test_read_timeout(sa_connect): + conn = await sa_connect(read_timeout=0.01) + with pytest.raises(aiomysql.OperationalError): + await conn.execute("DO SLEEP(1)") + + @pytest.mark.run_loop async def test_execute_text_select(sa_connect): conn = await sa_connect() diff --git a/tests/test_connection.py b/tests/test_connection.py index c0c1be3d..3e07795b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator): await connection_creator(connect_timeout=0.000000000001) +@pytest.mark.run_loop +async def test_read_timeout(connection_creator): + with pytest.raises(aiomysql.OperationalError): + con = await connection_creator(read_timeout=0.01) + cur = await con.cursor() + await cur.execute("DO SLEEP(1)") + + @pytest.mark.run_loop async def test_config_file(fill_my_cnf, connection_creator, mysql_params): tests_root = os.path.abspath(os.path.dirname(__file__))