From 15618344caa314f9eb9ad03ac93e01d7174da22b Mon Sep 17 00:00:00 2001 From: Mel Dafert Date: Fri, 19 Aug 2022 13:56:59 +0200 Subject: [PATCH] cli option --tls-version --- README.md | 3 +++ changelog.md | 4 ++++ mycli/AUTHORS | 1 + mycli/main.py | 8 ++++++-- mycli/sqlexecute.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e15f5057..35b03092 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,9 @@ $ sudo apt-get install mycli # Only on debian or ubuntu --ssl-cert PATH X509 cert in PEM format. --ssl-key PATH X509 key in PEM format. --ssl-cipher TEXT SSL cipher to use. + --tls-version [TLSv1|TLSv1.1|TLSv1.2|TLSv1.3] + TLS protocol version for secure connection. + --ssl-verify-server-cert Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default. diff --git a/changelog.md b/changelog.md index 159299d0..52d6f8f9 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,10 @@ TBD === +Features: +--------- +* Add `--tls-version` option to control the tls version used. + Internal: --------- * Pin `cryptography` to suppress `paramiko` warning, helping CI complete and presumably affecting some users. diff --git a/mycli/AUTHORS b/mycli/AUTHORS index 328805dd..dd276867 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -92,6 +92,7 @@ Contributors: * Zhongyang Guan * Arvind Mishra * Kevin Schmeichel + * Mel Dafert Created by: ----------- diff --git a/mycli/main.py b/mycli/main.py index 0561af81..6d3641db 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1097,6 +1097,9 @@ def get_last_query(self): @click.option('--ssl-key', help='X509 key in PEM format.', type=click.Path(exists=True)) @click.option('--ssl-cipher', help='SSL cipher to use.') +@click.option('--tls-version', + type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), + help='TLS protocol version for secure connection.') @click.option('--ssl-verify-server-cert', is_flag=True, help=('Verify server\'s "Common Name" in its cert against ' 'hostname used when connecting. This option is disabled ' @@ -1148,8 +1151,8 @@ def cli(database, user, host, port, socket, password, dbname, version, verbose, prompt, logfile, defaults_group_suffix, defaults_file, login_path, auto_vertical_output, local_infile, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, - ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, - list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, + tls_version, ssl_verify_server_cert, table, csv, warn, execute, + myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, init_command, charset, password_file): """A MySQL terminal client with auto-completion and syntax highlighting. @@ -1207,6 +1210,7 @@ def cli(database, user, host, port, socket, password, dbname, 'key': ssl_key and os.path.expanduser(ssl_key), 'capath': ssl_capath, 'cipher': ssl_cipher, + 'tls_version': tls_version, 'check_hostname': ssl_verify_server_cert, } diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index c0197079..38de4f09 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -176,11 +176,15 @@ def connect(self, database=None, user=None, password=None, host=None, if init_command and len(list(special.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + ssl_context = None + if ssl: + ssl_context = self._create_ssl_ctx(ssl) + conn = pymysql.connect( database=db, user=user, password=password, host=host, port=port, unix_socket=socket, use_unicode=True, charset=charset, autocommit=True, client_flag=client_flag, - local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli", defer_connect=defer_connect, init_command=init_command ) @@ -354,3 +358,40 @@ def reset_connection_id(self): def change_db(self, db): self.conn.select_db(db) self.dbname = db + + def _create_ssl_ctx(self, sslp): + import ssl + + ca = sslp.get("ca") + capath = sslp.get("capath") + hasnoca = ca is None and capath is None + ctx = ssl.create_default_context(cafile=ca, capath=capath) + ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + if "cert" in sslp: + ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) + if "cipher" in sslp: + ctx.set_ciphers(sslp["cipher"]) + + # raise this default to v1.1 or v1.2? + ctx.minimum_version = ssl.TLSVersion.TLSv1 + + if "tls_version" in sslp: + tls_version = sslp["tls_version"] + + if tls_version == "TLSv1": + ctx.minimum_version = ssl.TLSVersion.TLSv1 + ctx.maximum_version = ssl.TLSVersion.TLSv1 + elif tls_version == "TLSv1.1": + ctx.minimum_version = ssl.TLSVersion.TLSv1_1 + ctx.maximum_version = ssl.TLSVersion.TLSv1_1 + elif tls_version == "TLSv1.2": + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + elif tls_version == "TLSv1.3": + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + ctx.maximum_version = ssl.TLSVersion.TLSv1_3 + else: + _logger.error('Invalid tls version: %s', tls_version) + + return ctx