Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 'read_timeout' parameter #974

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="",
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand All @@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="",
init_command=init_command,
connect_timeout=connect_timeout,
read_default_group=read_default_group,
read_timeout=read_timeout,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="",
charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
connect_timeout=None, read_default_group=None, read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand Down Expand Up @@ -171,6 +173,8 @@ def __init__(self, host="localhost", user=None, password="",
when connecting.
:param read_default_group: Group to read from in the configuration
file.
:param read_timeout: The timeout for reading from the connection in seconds
(default: None - no timeout)
:param autocommit: Autocommit mode. None means use server default.
(default: False)
:param local_infile: boolean to enable the use of LOAD DATA LOCAL
Expand Down Expand Up @@ -257,6 +261,7 @@ def __init__(self, host="localhost", user=None, password="",

self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout

self._result = None
self._affected_rows = 0
Expand Down Expand Up @@ -654,12 +659,25 @@ async def _read_packet(self, packet_type=MysqlPacket):

async def _read_bytes(self, num_bytes):
try:
data = await self._reader.readexactly(num_bytes)
if self.read_timeout:
try:
data = await asyncio.wait_for(
self._reader.readexactly(num_bytes),
self.read_timeout
)
except asyncio.TimeoutError as e:
raise asyncio.TimeoutError("Read timeout exceeded") from e
else:
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except OSError as e:
except (OSError, asyncio.TimeoutError) as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except Exception as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
Expand Down
13 changes: 10 additions & 3 deletions tests/sa/test_sa_connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from unittest import mock

import aiomysql

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note test

Module 'aiomysql' is imported with both 'import' and 'import from'.
from aiomysql import sa, Cursor

import pytest
from sqlalchemy import MetaData, Table, Column, Integer, String, func, select
from sqlalchemy.schema import DropTable, CreateTable
from sqlalchemy.sql.expression import bindparam

import aiomysql
from aiomysql import sa, Cursor

meta = MetaData()
tbl = Table('sa_tbl', meta,
Column('id', Integer, nullable=False,
Expand Down Expand Up @@ -35,6 +35,13 @@
return connect


@pytest.mark.run_loop
async def test_read_timeout(sa_connect):
conn = await sa_connect(read_timeout=0.01)
with pytest.raises(aiomysql.OperationalError):
await conn.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_execute_text_select(sa_connect):
conn = await sa_connect()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator):
await connection_creator(connect_timeout=0.000000000001)


@pytest.mark.run_loop
async def test_read_timeout(connection_creator):
with pytest.raises(aiomysql.OperationalError):
con = await connection_creator(read_timeout=0.01)
cur = await con.cursor()
await cur.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
tests_root = os.path.abspath(os.path.dirname(__file__))
Expand Down
Loading