From 2d7cd4ec5907ad183ec089edf3e54e1b1aec8832 Mon Sep 17 00:00:00 2001 From: Huw Date: Tue, 7 May 2024 07:59:13 +0000 Subject: [PATCH] Support object as value in extra_credential --- tests/unit/test_client.py | 36 ++++++++++++++++++++++++++++++++++++ trino/client.py | 3 ++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 48ac3948..7305d164 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -866,6 +866,42 @@ def test_extra_credential_value_encoding(mock_get_and_post): assert constants.HEADER_EXTRA_CREDENTIAL in headers assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84" +def test_extra_credential_value_object(mock_get_and_post): + _, post = mock_get_and_post + + class TestCredential(object): + value = 0 + + def __str__(self): + self.value = self.value + 1 + return str(self.value) + + credential = TestCredential() + + req = TrinoRequest( + host="coordinator", + port=constants.DEFAULT_TLS_PORT, + client_session=ClientSession( + user="test", + extra_credential=[("foo", credential)] + ) + ) + + req.post("SELECT 1") + _, post_kwargs = post.call_args + headers = post_kwargs["headers"] + assert constants.HEADER_EXTRA_CREDENTIAL in headers + assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=initial" + + credential.value = "changed" + + # Make a second request, assert that credential has changed + req.post("SELECT 1") + _, post_kwargs = post.call_args + headers = post_kwargs["headers"] + assert constants.HEADER_EXTRA_CREDENTIAL in headers + assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed" + class MockGssapiCredentials: def __init__(self, name: gssapi.Name, usage: str): diff --git a/trino/client.py b/trino/client.py index 763e0ebc..cbcd15af 100644 --- a/trino/client.py +++ b/trino/client.py @@ -486,7 +486,8 @@ def http_headers(self) -> Dict[str, str]: # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) headers[constants.HEADER_EXTRA_CREDENTIAL] = \ ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + [f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}" + for tup in self._client_session.extra_credential]) return headers