diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 48ac3948..19c107f9 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -90,6 +90,7 @@ def test_request_headers(mock_get_and_post): catalog = "test_catalog" schema = "test_schema" user = "test_user" + authorization_user = "test_authorization_user" source = "test_source" timezone = "Europe/Brussels" accept_encoding_header = "accept-encoding" @@ -103,6 +104,7 @@ def test_request_headers(mock_get_and_post): port=8080, client_session=ClientSession( user=user, + authorization_user=authorization_user, source=source, catalog=catalog, schema=schema, @@ -127,6 +129,7 @@ def assert_headers(headers): assert headers[constants.HEADER_SCHEMA] == schema assert headers[constants.HEADER_SOURCE] == source assert headers[constants.HEADER_USER] == user + assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user assert headers[constants.HEADER_SESSION] == "" assert headers[constants.HEADER_TRANSACTION] is None assert headers[constants.HEADER_TIMEZONE] == timezone @@ -140,7 +143,7 @@ def assert_headers(headers): "catalog2=" + urllib.parse.quote("ROLE{catalog2_role}") ) assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}" - assert len(headers.keys()) == 12 + assert len(headers.keys()) == 13 req.post("URL") _, post_kwargs = post.call_args diff --git a/trino/client.py b/trino/client.py index 763e0ebc..017555a1 100644 --- a/trino/client.py +++ b/trino/client.py @@ -82,6 +82,8 @@ class ClientSession(object): :param user: associated with the query. It is useful for access control and query scheduling. + :param authorization_user: associated with the query. It is useful for access control + and query scheduling. :param source: associated with the query. It is useful for access control and query scheduling. :param catalog: to query. The *catalog* is associated with a Trino @@ -113,6 +115,7 @@ class ClientSession(object): def __init__( self, user: str, + authorization_user: str = None, catalog: str = None, schema: str = None, source: str = None, @@ -125,6 +128,7 @@ def __init__( timezone: str = None, ): self._user = user + self._authorization_user = authorization_user self._catalog = catalog self._schema = schema self._source = source @@ -144,6 +148,16 @@ def __init__( def user(self): return self._user + @property + def authorization_user(self): + with self._object_lock: + return self._authorization_user + + @authorization_user.setter + def authorization_user(self, authorization_user): + with self._object_lock: + self._authorization_user = authorization_user + @property def catalog(self): with self._object_lock: @@ -441,6 +455,7 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_SCHEMA] = self._client_session.schema headers[constants.HEADER_SOURCE] = self._client_session.source headers[constants.HEADER_USER] = self._client_session.user + headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME' headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}" @@ -630,6 +645,12 @@ def process(self, http_response) -> TrinoStatus: ): self._client_session.prepared_statements.pop(name, None) + if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers: + self._client_session.authorization_user = http_response.headers[constants.HEADER_SET_AUTHORIZATION_USER] + + if constants.HEADER_RESET_AUTHORIZATION_USER in http_response.headers: + self._client_session.authorization_user = None + self._next_uri = response.get("nextUri") data = response.get("data") if response.get("data") else [] diff --git a/trino/constants.py b/trino/constants.py index 2199105c..d4ba904d 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -56,6 +56,10 @@ HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities" +HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User" +HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User" +HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User" + LENGTH_TYPES = ["char", "varchar"] PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"] SCALE_TYPES = ["decimal"]