diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index 413a6bb0..7fd021f2 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -298,6 +298,23 @@ def test_trino_connection_certificate_auth(): assert cparams['auth']._key == key +def test_trino_connection_certificate_auth_cert_and_key_required(): + dialect = TrinoDialect() + cert = '/path/to/cert.pem' + key = '/path/to/key.pem' + url = make_url(f'trino://host/?cert={cert}') + _, cparams = dialect.create_connect_args(url) + + assert 'http_scheme' not in cparams + assert 'auth' not in cparams + + url = make_url(f'trino://host/?key={key}') + _, cparams = dialect.create_connect_args(url) + + assert 'http_scheme' not in cparams + assert 'auth' not in cparams + + def test_trino_connection_oauth2_auth(): dialect = TrinoDialect() url = make_url('trino://host/?externalAuthentication=true') diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index edcc372c..cc3c313d 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -131,7 +131,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any kwargs["http_scheme"] = "https" kwargs["auth"] = JWTAuthentication(unquote_plus(url.query["access_token"])) - if "cert" and "key" in url.query: + if "cert" in url.query and "key" in url.query: kwargs["http_scheme"] = "https" kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))