Skip to content

Commit

Permalink
Support SET SESSION AUTHORIZATION on trino-python-client
Browse files Browse the repository at this point in the history
  • Loading branch information
baohe-zhang authored and hashhar committed Jun 24, 2024
1 parent 856d8e9 commit 169226e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_request_headers(mock_get_and_post):
catalog = "test_catalog"
schema = "test_schema"
user = "test_user"
authorization_user = "test_authorization_user"
source = "test_source"
timezone = "Europe/Brussels"
accept_encoding_header = "accept-encoding"
Expand All @@ -103,6 +104,7 @@ def test_request_headers(mock_get_and_post):
port=8080,
client_session=ClientSession(
user=user,
authorization_user=authorization_user,
source=source,
catalog=catalog,
schema=schema,
Expand All @@ -127,6 +129,7 @@ def assert_headers(headers):
assert headers[constants.HEADER_SCHEMA] == schema
assert headers[constants.HEADER_SOURCE] == source
assert headers[constants.HEADER_USER] == user
assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user
assert headers[constants.HEADER_SESSION] == ""
assert headers[constants.HEADER_TRANSACTION] is None
assert headers[constants.HEADER_TIMEZONE] == timezone
Expand All @@ -140,7 +143,7 @@ def assert_headers(headers):
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
)
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
assert len(headers.keys()) == 12
assert len(headers.keys()) == 13

req.post("URL")
_, post_kwargs = post.call_args
Expand Down
21 changes: 21 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class ClientSession(object):
:param user: associated with the query. It is useful for access control
and query scheduling.
:param authorization_user: associated with the query. It is useful for access control
and query scheduling.
:param source: associated with the query. It is useful for access
control and query scheduling.
:param catalog: to query. The *catalog* is associated with a Trino
Expand Down Expand Up @@ -113,6 +115,7 @@ class ClientSession(object):
def __init__(
self,
user: str,
authorization_user: str = None,
catalog: str = None,
schema: str = None,
source: str = None,
Expand All @@ -125,6 +128,7 @@ def __init__(
timezone: str = None,
):
self._user = user
self._authorization_user = authorization_user
self._catalog = catalog
self._schema = schema
self._source = source
Expand All @@ -144,6 +148,16 @@ def __init__(
def user(self):
return self._user

@property
def authorization_user(self):
with self._object_lock:
return self._authorization_user

@authorization_user.setter
def authorization_user(self, authorization_user):
with self._object_lock:
self._authorization_user = authorization_user

@property
def catalog(self):
with self._object_lock:
Expand Down Expand Up @@ -441,6 +455,7 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_SCHEMA] = self._client_session.schema
headers[constants.HEADER_SOURCE] = self._client_session.source
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
Expand Down Expand Up @@ -631,6 +646,12 @@ def process(self, http_response) -> TrinoStatus:
):
self._client_session.prepared_statements.pop(name, None)

if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers:
self._client_session.authorization_user = http_response.headers[constants.HEADER_SET_AUTHORIZATION_USER]

if constants.HEADER_RESET_AUTHORIZATION_USER in http_response.headers:
self._client_session.authorization_user = None

self._next_uri = response.get("nextUri")

data = response.get("data") if response.get("data") else []
Expand Down
4 changes: 4 additions & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@

HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"

HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User"
HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User"
HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"

LENGTH_TYPES = ["char", "varchar"]
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
SCALE_TYPES = ["decimal"]

0 comments on commit 169226e

Please sign in to comment.