Skip to content

Commit

Permalink
Merge pull request #312 from terrycain/sha256
Browse files Browse the repository at this point in the history
MySQL 8 Compatibility and SHA256 authentication plugin support
  • Loading branch information
terrycain authored Jul 9, 2018
2 parents 425f81e + 332a249 commit 884802c
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 36 deletions.
199 changes: 178 additions & 21 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# from aiomysql.utils import _convert_to_str
from .cursors import Cursor
from .utils import _ConnectionContextManager, _ContextManager
# from .log import logger
from .log import logger


DEFAULT_USER = getpass.getuser()
Expand All @@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
no_delay=None, autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name=''):
program_name='', server_public_key=None):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
no_delay=None, autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name=''):
program_name='', server_public_key=None):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
(default: Server Default)
:param program_name: Program name string to provide when
handshaking with MySQL. (default: sys.argv[0])
:param server_public_key: SHA256 authentication plugin public
key value.
:param loop: asyncio loop
"""
self._loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
self._client_auth_plugin = auth_plugin
self._server_auth_plugin = ""
self._auth_plugin_used = ""
self.server_public_key = server_public_key
self.salt = None

# TODO somehow import version from __init__.py
self._connect_attrs = {
Expand Down Expand Up @@ -712,6 +716,20 @@ async def _request_authentication(self):
if auth_plugin in ('', 'mysql_native_password'):
authresp = _auth.scramble_native_password(
self._password.encode('latin1'), self.salt)
elif auth_plugin == 'caching_sha2_password':
if self._password:
authresp = _auth.scramble_caching_sha2(
self._password.encode('latin1'), self.salt
)
# Else: empty password
elif auth_plugin == 'sha256_password':
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
authresp = self._password.encode('latin1') + b'\0'
elif self._password:
authresp = b'\1' # request public key
else:
authresp = b'\0' # empty password

elif auth_plugin in ('', 'mysql_clear_password'):
authresp = self._password.encode('latin1') + b'\0'

Expand Down Expand Up @@ -768,35 +786,174 @@ async def _request_authentication(self):
auth_packet.read_all()) + b'\0'
self.write_packet(data)
await self._read_packet()
elif auth_packet.is_extra_auth_data():
if auth_plugin == "caching_sha2_password":
await self.caching_sha2_password_auth(auth_packet)
elif auth_plugin == "sha256_password":
await self.sha256_password_auth(auth_packet)
else:
raise OperationalError("Received extra packet "
"for auth method %r", auth_plugin)

async def _process_auth(self, plugin_name, auth_packet):
if plugin_name == b"mysql_native_password":
# https://dev.mysql.com/doc/internals/en/
# secure-password-authentication.html#packet-Authentication::
# Native41
data = _auth.scramble_native_password(
self._password.encode('latin1'),
auth_packet.read_all())
elif plugin_name == b"mysql_old_password":
# https://dev.mysql.com/doc/internals/en/
# old-password-authentication.html
data = _auth.scramble_old_password(self._password.encode('latin1'),
auth_packet.read_all()) + b'\0'
elif plugin_name == b"mysql_clear_password":
# https://dev.mysql.com/doc/internals/en/
# clear-text-authentication.html
data = self._password.encode('latin1') + b'\0'
# These auth plugins do their own packet handling
if plugin_name == b"caching_sha2_password":
await self.caching_sha2_password_auth(auth_packet)
self._auth_plugin_used = plugin_name.decode()
elif plugin_name == b"sha256_password":
await self.sha256_password_auth(auth_packet)
self._auth_plugin_used = plugin_name.decode()
else:

if plugin_name == b"mysql_native_password":
# https://dev.mysql.com/doc/internals/en/
# secure-password-authentication.html#packet-Authentication::
# Native41
data = _auth.scramble_native_password(
self._password.encode('latin1'),
auth_packet.read_all())
elif plugin_name == b"mysql_old_password":
# https://dev.mysql.com/doc/internals/en/
# old-password-authentication.html
data = _auth.scramble_old_password(
self._password.encode('latin1'),
auth_packet.read_all()
) + b'\0'
elif plugin_name == b"mysql_clear_password":
# https://dev.mysql.com/doc/internals/en/
# clear-text-authentication.html
data = self._password.encode('latin1') + b'\0'
else:
raise OperationalError(
2059, "Authentication plugin '{0}'"
" not configured".format(plugin_name)
)

self.write_packet(data)
pkt = await self._read_packet()
pkt.check_error()

self._auth_plugin_used = plugin_name.decode()

return pkt

async def caching_sha2_password_auth(self, pkt):
# No password fast path
if not self._password:
self.write_packet(b'')
pkt = await self._read_packet()
pkt.check_error()
return pkt

if pkt.is_auth_switch_request():
# Try from fast auth
logger.debug("caching sha2: Trying fast path")
self.salt = pkt.read_all()
scrambled = _auth.scramble_caching_sha2(
self._password.encode('latin1'), self.salt
)

self.write_packet(scrambled)
pkt = await self._read_packet()
pkt.check_error()

# else: fast auth is tried in initial handshake

if not pkt.is_extra_auth_data():
raise OperationalError(
2059, "Authentication plugin '%s' not configured" % plugin_name
"caching sha2: Unknown packet "
"for fast auth: {0}".format(pkt._data[:1])
)

# magic numbers:
# 2 - request public key
# 3 - fast auth succeeded
# 4 - need full auth

pkt.advance(1)
n = pkt.read_uint8()

if n == 3:
logger.debug("caching sha2: succeeded by fast path.")
pkt = await self._read_packet()
pkt.check_error() # pkt must be OK packet
return pkt

if n != 4:
raise OperationalError("caching sha2: Unknown "
"result for fast auth: {0}".format(n))

logger.debug("caching sha2: Trying full auth...")

if self._ssl_context:
logger.debug("caching sha2: Sending plain "
"password via secure connection")
self.write_packet(self._password.encode('latin1') + b'\0')
pkt = await self._read_packet()
pkt.check_error()
return pkt

if not self.server_public_key:
self.write_packet(b'\x02')
pkt = await self._read_packet() # Request public key
pkt.check_error()

if not pkt.is_extra_auth_data():
raise OperationalError(
"caching sha2: Unknown packet "
"for public key: {0}".format(pkt._data[:1])
)

self.server_public_key = pkt._data[1:]
logger.debug(self.server_public_key.decode('ascii'))

data = _auth.sha2_rsa_encrypt(
self._password.encode('latin1'), self.salt,
self.server_public_key
)
self.write_packet(data)
pkt = await self._read_packet()
pkt.check_error()

self._auth_plugin_used = plugin_name
async def sha256_password_auth(self, pkt):
if self._ssl_context:
logger.debug("sha256: Sending plain password")
data = self._password.encode('latin1') + b'\0'
self.write_packet(data)
pkt = await self._read_packet()
pkt.check_error()
return pkt

if pkt.is_auth_switch_request():
self.salt = pkt.read_all()
if not self.server_public_key and self._password:
# Request server public key
logger.debug("sha256: Requesting server public key")
self.write_packet(b'\1')
pkt = await self._read_packet()
pkt.check_error()

if pkt.is_extra_auth_data():
self.server_public_key = pkt._data[1:]
logger.debug(
"Received public key:\n",
self.server_public_key.decode('ascii')
)

if self._password:
if not self.server_public_key:
raise OperationalError("Couldn't receive server's public key")

data = _auth.sha2_rsa_encrypt(
self._password.encode('latin1'), self.salt,
self.server_public_key
)
else:
data = b''

self.write_packet(data)
pkt = await self._read_packet()
pkt.check_error()
return pkt

# _mysql support
Expand Down
4 changes: 3 additions & 1 deletion docs/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ Example::
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
no_delay=False, autocommit=False, echo=False,
ssl=None, auth_plugin='', program_name='', loop=None)
ssl=None, auth_plugin='', program_name='',
server_public_key=None, loop=None)

A :ref:`coroutine <coroutine>` that connects to MySQL.

Expand Down Expand Up @@ -89,6 +90,7 @@ Example::
(default: Server Default)
:param program_name: Program name string to provide when
handshaking with MySQL. (default: sys.argv[0])
:param server_public_key: SHA256 authenticaiton plugin public key value.
:param loop: asyncio event loop instance or ``None`` for default one.
:returns: :class:`Connection` instance.

Expand Down
38 changes: 38 additions & 0 deletions examples/example_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
import ssl
import aiomysql

ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.check_hostname = False
ctx.load_verify_locations(cafile='../tests/ssl_resources/ssl/ca.pem')


async def main():
async with aiomysql.create_pool(
host='localhost', port=3306, user='root',
password='rootpw', ssl=ctx,
auth_plugin='mysql_clear_password') as pool:

async with pool.get() as conn:
async with conn.cursor() as cur:
# Run simple command
await cur.execute("SHOW DATABASES;")
value = await cur.fetchall()

values = [item[0] for item in value]
# Spot check the answers, we should at least have mysql
# and information_schema
assert 'mysql' in values, \
'Could not find the "mysql" table'
assert 'information_schema' in values, \
'Could not find the "mysql" table'

# Check TLS variables
await cur.execute("SHOW STATUS LIKE 'Ssl_version%';")
value = await cur.fetchone()

# The context has TLS
assert value[1].startswith('TLS'), \
'Not connected to the database with TLS'

asyncio.get_event_loop().run_until_complete(main())
Loading

0 comments on commit 884802c

Please sign in to comment.