From 19ab680e3e14b0c741a089f30ddab792b02332dc Mon Sep 17 00:00:00 2001 From: Terry Cain Date: Sun, 8 Jul 2018 17:37:43 +0100 Subject: [PATCH 1/5] Initial SHA256 Implementation --- aiomysql/connection.py | 157 +++++++++++++++++++++++++++++++++++++++-- docs/connection.rst | 4 +- tests/conftest.py | 18 ++--- 3 files changed, 165 insertions(+), 14 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 3f4b9e04..a6c1e581 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -41,7 +41,7 @@ # from aiomysql.utils import _convert_to_str from .cursors import Cursor from .utils import _ConnectionContextManager, _ContextManager -# from .log import logger +from .log import logger DEFAULT_USER = getpass.getuser() @@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, no_delay=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name=''): + program_name='', server_public_key=None): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, no_delay=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name=''): + program_name='', server_public_key=None): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="", (default: Server Default) :param program_name: Program name string to provide when handshaking with MySQL. (default: sys.argv[0]) + :param server_public_key: SHA256 authentication plugin public + key value. :param loop: asyncio loop """ self._loop = loop or asyncio.get_event_loop() @@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="", self._client_auth_plugin = auth_plugin self._server_auth_plugin = "" self._auth_plugin_used = "" + self.server_public_key = server_public_key + self.salt = None # TODO somehow import version from __init__.py self._connect_attrs = { @@ -711,6 +715,20 @@ async def _request_authentication(self): if auth_plugin in ('', 'mysql_native_password'): authresp = _auth.scramble_native_password( self._password.encode('latin1'), self.salt) + elif auth_plugin == 'caching_sha2_password': + if self._password: + authresp = _auth.scramble_caching_sha2( + self._password.encode('latin1'), self.salt + ) + # Else: empty password + elif auth_plugin == 'sha256_password': + if self._ssl_context and self.server_capabilities & CLIENT.SSL: + authresp = self._password.encode('latin1') + b'\0' + elif self._password: + authresp = b'\1' # request public key + else: + authresp = b'\0' # empty password + elif auth_plugin in ('', 'mysql_clear_password'): authresp = self._password.encode('latin1') + b'\0' @@ -767,9 +785,21 @@ async def _request_authentication(self): auth_packet.read_all()) + b'\0' self.write_packet(data) await self._read_packet() + elif auth_packet.is_extra_auth_data(): + if auth_plugin == "caching_sha2_password": + await self.caching_sha2_password_auth(auth_packet) + elif auth_plugin == "sha256_password": + await self.sha256_password_auth(auth_packet) + else: + raise OperationalError("Received extra packet " + "for auth method %r", auth_plugin) async def _process_auth(self, plugin_name, auth_packet): - if plugin_name == b"mysql_native_password": + if plugin_name == b"caching_sha2_password": + return self.caching_sha2_password_auth(auth_packet) + elif plugin_name == b"sha256_password": + return self.sha256_password_auth(auth_packet) + elif plugin_name == b"mysql_native_password": # https://dev.mysql.com/doc/internals/en/ # secure-password-authentication.html#packet-Authentication:: # Native41 @@ -798,6 +828,125 @@ async def _process_auth(self, plugin_name, auth_packet): return pkt + async def caching_sha2_password_auth(self, pkt): + # No password fast path + if not self._password: + self.write_packet(b'') + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if pkt.is_auth_switch_request(): + # Try from fast auth + logger.debug("caching sha2: Trying fast path") + self.salt = pkt.read_all() + scrambled = _auth.scramble_caching_sha2( + self._password.encode('latin1'), self.salt + ) + + self.write_packet(scrambled) + pkt = await self._read_packet() + pkt.check_error() + + # else: fast auth is tried in initial handshake + + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet " + "for fast auth: {0}".format(pkt._data[:1]) + ) + + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + pkt.advance(1) + n = pkt.read_uint8() + + if n == 3: + logger.debug("caching sha2: succeeded by fast path.") + pkt = await self._read_packet() + pkt.check_error() # pkt must be OK packet + return pkt + + if n != 4: + raise OperationalError("caching sha2: Unknown " + "result for fast auth: {0}".format(n)) + + logger.debug("caching sha2: Trying full auth...") + + if self._ssl_context: + logger.debug("caching sha2: Sending plain " + "password via secure connection") + self.write_packet(self._password.encode('latin1') + b'\0') + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if not self.server_public_key: + self.write_packet(b'\x02') + pkt = await self._read_packet() # Request public key + pkt.check_error() + + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet " + "for public key: {0}".format(pkt._data[:1]) + ) + + self.server_public_key = pkt._data[1:] + logger.debug(self.server_public_key.decode('ascii')) + + data = _auth.sha2_rsa_encrypt( + self._password.encode('latin1'), self.salt, + self.server_public_key + ) + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() + + async def sha256_password_auth(self, pkt): + if self._ssl_context: + logger.debug("sha256: Sending plain password") + data = self._password.encode('latin1') + b'\0' + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if pkt.is_auth_switch_request(): + self.salt = pkt.read_all() + if not self.server_public_key and self._password: + # Request server public key + logger.debug("sha256: Requesting server public key") + self.write_packet(b'\1') + pkt = await self._read_packet() + pkt.check_error() + + if pkt.is_extra_auth_data(): + self.server_public_key = pkt._data[1:] + logger.debug( + "Received public key:\n", + self.server_public_key.decode('ascii') + ) + + if self._password: + if not self.server_public_key: + raise OperationalError("Couldn't receive server's public key") + + data = _auth.sha2_rsa_encrypt( + self._password.encode('latin1'), self.salt, + self.server_public_key + ) + else: + data = b'' + + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() + return pkt + # _mysql support def thread_id(self): return self.server_thread_id[0] diff --git a/docs/connection.rst b/docs/connection.rst index 604a0bc3..de3dc0c8 100644 --- a/docs/connection.rst +++ b/docs/connection.rst @@ -47,7 +47,8 @@ Example:: client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, no_delay=False, autocommit=False, echo=False, - ssl=None, auth_plugin='', program_name='', loop=None) + ssl=None, auth_plugin='', program_name='', + server_public_key=None, loop=None) A :ref:`coroutine ` that connects to MySQL. @@ -89,6 +90,7 @@ Example:: (default: Server Default) :param program_name: Program name string to provide when handshaking with MySQL. (default: sys.argv[0]) + :param server_public_key: SHA256 authenticaiton plugin public key value. :param loop: asyncio event loop instance or ``None`` for default one. :returns: :class:`Connection` instance. diff --git a/tests/conftest.py b/tests/conftest.py index 4a9b9bbb..ff2a59f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,15 +35,15 @@ def pytest_generate_tests(metafunc): loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio'] metafunc.parametrize("loop_type", loop_type) - # if 'mysql_tag' in metafunc.fixturenames: - # tags = set(metafunc.config.option.mysql_tag) - # if not tags: - # tags = ['5.7'] - # elif 'all' in tags: - # tags = ['5.6', '5.7', '8.0'] - # else: - # tags = list(tags) - # metafunc.parametrize("mysql_tag", tags, scope='session') + if 'mysql_tag' in metafunc.fixturenames: + # tags = set(metafunc.config.option.mysql_tag) + # if not tags: + # tags = ['5.7'] + # elif 'all' in tags: + # tags = ['5.6', '5.7', '8.0'] + # else: + # tags = list(tags) + metafunc.parametrize("mysql_tag", ['5.6', '8.0'], scope='session') # This is here unless someone fixes the generate_tests bit From 7c3bfbf51c6284b03c35fb15ea7e71adb15d7368 Mon Sep 17 00:00:00 2001 From: Terry Cain Date: Sun, 8 Jul 2018 19:23:25 +0100 Subject: [PATCH 2/5] Narrowed LIKE scope as was picking up unrelated rows Fixed unawaited authentication plugin coros --- aiomysql/connection.py | 58 +++++++++++++++++++++++------------------ examples/example_ssl.py | 38 +++++++++++++++++++++++++++ tests/conftest.py | 18 ++++++------- tests/test_ssl.py | 2 +- 4 files changed, 80 insertions(+), 36 deletions(-) create mode 100644 examples/example_ssl.py diff --git a/aiomysql/connection.py b/aiomysql/connection.py index a6c1e581..0fa69bd2 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -795,38 +795,44 @@ async def _request_authentication(self): "for auth method %r", auth_plugin) async def _process_auth(self, plugin_name, auth_packet): + # These auth plugins do their own packet handling if plugin_name == b"caching_sha2_password": - return self.caching_sha2_password_auth(auth_packet) + await self.caching_sha2_password_auth(auth_packet) elif plugin_name == b"sha256_password": - return self.sha256_password_auth(auth_packet) - elif plugin_name == b"mysql_native_password": - # https://dev.mysql.com/doc/internals/en/ - # secure-password-authentication.html#packet-Authentication:: - # Native41 - data = _auth.scramble_native_password( - self._password.encode('latin1'), - auth_packet.read_all()) - elif plugin_name == b"mysql_old_password": - # https://dev.mysql.com/doc/internals/en/ - # old-password-authentication.html - data = _auth.scramble_old_password(self._password.encode('latin1'), - auth_packet.read_all()) + b'\0' - elif plugin_name == b"mysql_clear_password": - # https://dev.mysql.com/doc/internals/en/ - # clear-text-authentication.html - data = self._password.encode('latin1') + b'\0' + await self.sha256_password_auth(auth_packet) else: - raise OperationalError( - 2059, "Authentication plugin '%s' not configured" % plugin_name - ) - self.write_packet(data) - pkt = await self._read_packet() - pkt.check_error() + if plugin_name == b"mysql_native_password": + # https://dev.mysql.com/doc/internals/en/ + # secure-password-authentication.html#packet-Authentication:: + # Native41 + data = _auth.scramble_native_password( + self._password.encode('latin1'), + auth_packet.read_all()) + elif plugin_name == b"mysql_old_password": + # https://dev.mysql.com/doc/internals/en/ + # old-password-authentication.html + data = _auth.scramble_old_password( + self._password.encode('latin1'), + auth_packet.read_all() + ) + b'\0' + elif plugin_name == b"mysql_clear_password": + # https://dev.mysql.com/doc/internals/en/ + # clear-text-authentication.html + data = self._password.encode('latin1') + b'\0' + else: + raise OperationalError( + 2059, "Authentication plugin '{0}'" + " not configured".format(plugin_name) + ) + + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name - return pkt + return pkt async def caching_sha2_password_auth(self, pkt): # No password fast path diff --git a/examples/example_ssl.py b/examples/example_ssl.py new file mode 100644 index 00000000..e66c267d --- /dev/null +++ b/examples/example_ssl.py @@ -0,0 +1,38 @@ +import asyncio +import ssl +import aiomysql + +ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +ctx.check_hostname = False +ctx.load_verify_locations(cafile='../tests/ssl_resources/ssl/ca.pem') + + +async def main(): + async with aiomysql.create_pool( + host='localhost', port=3306, user='root', + password='rootpw', ssl=ctx, + auth_plugin='mysql_clear_password') as pool: + + async with pool.get() as conn: + async with conn.cursor() as cur: + # Run simple command + await cur.execute("SHOW DATABASES;") + value = await cur.fetchall() + + values = [item[0] for item in value] + # Spot check the answers, we should at least have mysql + # and information_schema + assert 'mysql' in values, \ + 'Could not find the "mysql" table' + assert 'information_schema' in values, \ + 'Could not find the "mysql" table' + + # Check TLS variables + await cur.execute("SHOW STATUS LIKE 'Ssl_version%';") + value = await cur.fetchone() + + # The context has TLS + assert value[1].startswith('TLS'), \ + 'Not connected to the database with TLS' + +asyncio.get_event_loop().run_until_complete(main()) diff --git a/tests/conftest.py b/tests/conftest.py index ff2a59f8..5fe3f93a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,14 +36,14 @@ def pytest_generate_tests(metafunc): metafunc.parametrize("loop_type", loop_type) if 'mysql_tag' in metafunc.fixturenames: - # tags = set(metafunc.config.option.mysql_tag) - # if not tags: - # tags = ['5.7'] - # elif 'all' in tags: - # tags = ['5.6', '5.7', '8.0'] - # else: - # tags = list(tags) - metafunc.parametrize("mysql_tag", ['5.6', '8.0'], scope='session') + tags = set(metafunc.config.option.mysql_tag) + if not tags: + tags = ['5.6', '8.0'] + elif 'all' in tags: + tags = ['5.6', '5.7', '8.0'] + else: + tags = list(tags) + metafunc.parametrize("mysql_tag", tags, scope='session') # This is here unless someone fixes the generate_tests bit @@ -288,7 +288,7 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): assert result['have_ssl'] == "YES", \ "SSL Not Enabled on docker'd MySQL" - cursor.execute("SHOW STATUS LIKE '%Ssl_version%'") + cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") result = cursor.fetchone() # As we connected with TLS, it should start with that :D diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 07c8ef61..88c5c5c9 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -22,7 +22,7 @@ async def test_tls_connect(mysql_server, loop): 'Could not find the "mysql" table' # Check TLS variables - await cur.execute("SHOW STATUS LIKE '%Ssl_version%';") + await cur.execute("SHOW STATUS LIKE 'Ssl_version%';") value = await cur.fetchone() # The context has TLS From b9b7b9a23594afc5091c07f18663e56e548138a4 Mon Sep 17 00:00:00 2001 From: Terry Cain Date: Sun, 8 Jul 2018 19:55:03 +0100 Subject: [PATCH 3/5] Fixed testcase asserting default auth plugin name Fixed bug where used auth plugin was not updated --- aiomysql/connection.py | 2 ++ tests/test_ssl.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 0fa69bd2..ef99d483 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -798,8 +798,10 @@ async def _process_auth(self, plugin_name, auth_packet): # These auth plugins do their own packet handling if plugin_name == b"caching_sha2_password": await self.caching_sha2_password_auth(auth_packet) + self._auth_plugin_used = plugin_name elif plugin_name == b"sha256_password": await self.sha256_password_auth(auth_packet) + self._auth_plugin_used = plugin_name else: if plugin_name == b"mysql_native_password": diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 88c5c5c9..044d759e 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -44,9 +44,15 @@ async def test_auth_plugin_renegotiation(mysql_server, loop): assert len(value), 'No databases found' + # Check we tried to use the cleartext plugin assert conn._client_auth_plugin == 'mysql_clear_password', \ 'Client did not try clear password auth' - assert conn._server_auth_plugin == 'mysql_native_password', \ + + # Check the server asked us to use MySQL's default plugin + assert conn._server_auth_plugin in ( + 'mysql_native_password', 'caching_sha2_password'), \ 'Server did not ask for native auth' - assert conn._auth_plugin_used == b'mysql_native_password', \ - 'Client did not renegotiate with native auth' + # Check we actually used the servers default plugin + assert conn._auth_plugin_used in ( + b'mysql_native_password', b'caching_sha2_password'), \ + 'Client did not renegotiate with server\'s default auth' From 6532b0d1e50c8c92b2d68ed690d664e6d7b922b6 Mon Sep 17 00:00:00 2001 From: Terry Cain Date: Sun, 8 Jul 2018 22:33:25 +0100 Subject: [PATCH 4/5] Normalised type of _auth_plugin_used. Added SHA265 tests --- aiomysql/connection.py | 6 +++--- tests/conftest.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_sha_connection.py | 25 +++++++++++++++++++++++++ tests/test_ssl.py | 2 +- 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 tests/test_sha_connection.py diff --git a/aiomysql/connection.py b/aiomysql/connection.py index ef99d483..290712f9 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -798,10 +798,10 @@ async def _process_auth(self, plugin_name, auth_packet): # These auth plugins do their own packet handling if plugin_name == b"caching_sha2_password": await self.caching_sha2_password_auth(auth_packet) - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() elif plugin_name == b"sha256_password": await self.sha256_password_auth(auth_packet) - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() else: if plugin_name == b"mysql_native_password": @@ -832,7 +832,7 @@ async def _process_auth(self, plugin_name, auth_packet): pkt = await self._read_packet() pkt.check_error() - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() return pkt diff --git a/tests/conftest.py b/tests/conftest.py index 5fe3f93a..cbcfa1f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -218,6 +218,14 @@ def docker(): return APIClient(version='auto') +@pytest.fixture(autouse=True) +def ensure_mysql_verison(request, mysql_tag): + if request.node.get_marker('mysql_verison'): + if request.node.get_marker('mysql_verison').args[0] != mysql_tag: + pytest.skip('Not applicable for ' + 'MySQL version: {0}'.format(mysql_tag)) + + @pytest.fixture(scope='session') def mysql_server(unused_port, docker, session_id, mysql_tag, request): if not request.config.option.no_pull: @@ -295,6 +303,32 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): assert result['Value'].startswith('TLS'), \ "Not connected to the database with TLS" + # Create Databases + cursor.execute('CREATE DATABASE test_pymysql ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + cursor.execute('CREATE DATABASE test_pymysql2 ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + + # Do MySQL8+ Specific Setup + if mysql_tag in ('8.0',): + # Create Users to test SHA256 + cursor.execute('CREATE USER user_sha256 ' + 'IDENTIFIED WITH "sha256_password" ' + 'BY "pass_sha256"') + cursor.execute('CREATE USER nopass_sha256 ' + 'IDENTIFIED WITH "sha256_password"') + cursor.execute('CREATE USER user_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'BY "pass_caching_sha2"') + cursor.execute('CREATE USER nopass_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'PASSWORD EXPIRE NEVER') + cursor.execute('FLUSH PRIVILEGES') + break except Exception as err: time.sleep(delay) diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py new file mode 100644 index 00000000..f0ac1ddd --- /dev/null +++ b/tests/test_sha_connection.py @@ -0,0 +1,25 @@ +import copy +from aiomysql import create_pool + +import pytest + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +@pytest.mark.parametrize("user,password,plugin", [ + ("nopass_sha256", None, 'sha256_password'), + ("user_sha256", 'pass_sha256', 'sha256_password'), + ("nopass_caching_sha2", None, 'caching_sha2_password'), + ("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'), +]) +async def test_sha(mysql_server, loop, user, password, plugin): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = user + connection_data['password'] = password + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == plugin diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 044d759e..ff1ea740 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -54,5 +54,5 @@ async def test_auth_plugin_renegotiation(mysql_server, loop): 'Server did not ask for native auth' # Check we actually used the servers default plugin assert conn._auth_plugin_used in ( - b'mysql_native_password', b'caching_sha2_password'), \ + 'mysql_native_password', 'caching_sha2_password'), \ 'Client did not renegotiate with server\'s default auth' From 332a249efc0bb603da670f027a01a53afff36453 Mon Sep 17 00:00:00 2001 From: Terry Cain Date: Sun, 8 Jul 2018 23:36:53 +0100 Subject: [PATCH 5/5] Sped up tests Pytest possibly has some bug in the fixture scoping when theres multiple sets of parameterisation Unrolling the test parameterisation for SHA meant less containers spun up and down. --- tests/conftest.py | 3 ++ tests/test_sha_connection.py | 76 +++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cbcfa1f6..16304adf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,6 +228,8 @@ def ensure_mysql_verison(request, mysql_tag): @pytest.fixture(scope='session') def mysql_server(unused_port, docker, session_id, mysql_tag, request): + print('\nSTARTUP CONTAINER - {0}\n'.format(mysql_tag)) + if not request.config.option.no_pull: docker.pull('mysql:{}'.format(mysql_tag)) @@ -342,5 +344,6 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): yield container finally: + print('\nTEARDOWN CONTAINER - {0}\n'.format(mysql_tag)) docker.kill(container=container['Id']) docker.remove_container(container['Id']) diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index f0ac1ddd..f2a108d8 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -4,22 +4,78 @@ import pytest +# You could parameterise these tests with this, but then pytest +# does some funky stuff and spins up and tears down containers +# per function call. Remember it would be +# mysql_versions * event_loops * 4 auth tests ~= 3*2*4 ~= 24 tests + +# As the MySQL daemon restarts at least 3 times in the container +# before it becomes stable, there's a sleep(10) so that's +# around a 4min wait time. + +# @pytest.mark.parametrize("user,password,plugin", [ +# ("nopass_sha256", None, 'sha256_password'), +# ("user_sha256", 'pass_sha256', 'sha256_password'), +# ("nopass_caching_sha2", None, 'caching_sha2_password'), +# ("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'), +# ]) + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_sha256_nopw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'nopass_sha256' + connection_data['password'] = None + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'sha256_password' + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_sha256_pw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'user_sha256' + connection_data['password'] = 'pass_sha256' + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'sha256_password' + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_cached_sha256_nopw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'nopass_caching_sha2' + connection_data['password'] = None + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'caching_sha2_password' + + @pytest.mark.mysql_verison('8.0') @pytest.mark.run_loop -@pytest.mark.parametrize("user,password,plugin", [ - ("nopass_sha256", None, 'sha256_password'), - ("user_sha256", 'pass_sha256', 'sha256_password'), - ("nopass_caching_sha2", None, 'caching_sha2_password'), - ("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'), -]) -async def test_sha(mysql_server, loop, user, password, plugin): +async def test_cached_sha256_pw(mysql_server, loop): connection_data = copy.copy(mysql_server['conn_params']) - connection_data['user'] = user - connection_data['password'] = password + connection_data['user'] = 'user_caching_sha2' + connection_data['password'] = 'pass_caching_sha2' async with create_pool(**connection_data, loop=loop) as pool: async with pool.get() as conn: # User doesnt have any permissions to look at DBs # But as 8.0 will default to caching_sha2_password - assert conn._auth_plugin_used == plugin + assert conn._auth_plugin_used == 'caching_sha2_password'