Skip to content

Commit

Permalink
Reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
damian3031 committed Jan 20, 2023
1 parent a45c000 commit 435a30f
Showing 1 changed file with 79 additions and 34 deletions.
113 changes: 79 additions & 34 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,30 @@
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal
from time import sleep
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import pytz
import requests
from pytz.tzinfo import BaseTzInfo
from tzlocal import get_localzone_name # type: ignore
from tzlocal import get_localzone_name

import trino.logging
from trino import constants, exceptions

try:
from zoneinfo import ZoneInfo # type: ignore
from zoneinfo import ZoneInfo

except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo # type: ignore
Expand All @@ -75,7 +87,7 @@
else:
PROXIES = {}

_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$")

T = TypeVar("T")

Expand Down Expand Up @@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]:
"{}={}".format(catalog, urllib.parse.quote(str(role)))
for catalog, role in self._client_session.roles.items()
)
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
if (
self._client_session.client_tags is not None
and len(self._client_session.client_tags) > 0
):
headers[constants.HEADER_CLIENT_TAGS] = ",".join(
self._client_session.client_tags
)

headers[constants.HEADER_SESSION] = ",".join(
# ``name`` must not contain ``=``
Expand All @@ -486,18 +503,23 @@ def http_headers(self) -> Dict[str, str]:
transaction_id = self._client_session.transaction_id
headers[constants.HEADER_TRANSACTION] = transaction_id

if self._client_session.extra_credential is not None and \
len(self._client_session.extra_credential) > 0:
if (
self._client_session.extra_credential is not None
and len(self._client_session.extra_credential) > 0
):

for tup in self._client_session.extra_credential:
self._verify_extra_credential(tup)

# HTTP 1.1 section 4.2 combine multiple extra credentials into a
# comma-separated value
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
", ".join(
[f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join(
[
f"{tup[0]}={urllib.parse.quote_plus(tup[1])}"
for tup in self._client_session.extra_credential
]
)

return headers

Expand Down Expand Up @@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
while http_response is not None and http_response.is_redirect:
location = http_response.headers["Location"]
url = self._redirect_handler.handle(location)
logger.info("redirect %s from %s to %s", http_response.status_code, location, url)
logger.info(
"redirect %s from %s to %s",
http_response.status_code,
location,
url,
)
http_response = self._post(
url,
data=data,
Expand Down Expand Up @@ -606,7 +633,7 @@ def raise_response_error(self, http_response):
raise exceptions.HttpError(
"error {}{}".format(
http_response.status_code,
": {}".format(http_response.content) if http_response.content else "",
": {}".format(repr(http_response.content)) if http_response.content else "",
)
)

Expand All @@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus:
self._client_session.properties[key] = value

if constants.HEADER_SET_CATALOG in http_response.headers:
self._client_session.catalog = http_response.headers[constants.HEADER_SET_CATALOG]
self._client_session.catalog = http_response.headers[
constants.HEADER_SET_CATALOG
]

if constants.HEADER_SET_SCHEMA in http_response.headers:
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]
self._client_session.schema = http_response.headers[
constants.HEADER_SET_SCHEMA
]

if constants.HEADER_SET_ROLE in http_response.headers:
for key, value in get_roles_values(
http_response.headers, constants.HEADER_SET_ROLE
http_response.headers, constants.HEADER_SET_ROLE
):
self._client_session.roles[key] = value

Expand Down Expand Up @@ -676,12 +707,16 @@ def _verify_extra_credential(self, header):
key = header[0]

if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX.match(key):
raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'")
raise ValueError(
f"whitespace or '=' are disallowed in extra credential '{key}'"
)

try:
key.encode().decode('ascii')
key.encode().decode("ascii")
except UnicodeDecodeError:
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
raise ValueError(
f"only ASCII characters are allowed in extra credential '{key}'"
)


class TrinoResult(object):
Expand Down Expand Up @@ -847,7 +882,10 @@ def cancel(self) -> None:

def is_finished(self) -> bool:
import warnings
warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)

warnings.warn(
"is_finished is deprecated, use finished instead", DeprecationWarning
)
return self.finished

@property
Expand Down Expand Up @@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> Optional[float]:
if value is None:
return None
if value == 'Infinity':
if value == "Infinity":
return float("inf")
if value == '-Infinity':
if value == "-Infinity":
return float("-inf")
if value == 'NaN':
if value == "NaN":
return float("nan")
return float(value)

Expand Down Expand Up @@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]):
def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
if values is None:
return None
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))
return tuple(
self.mappers[index].map(value) for index, value in enumerate(values)
)


class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
Expand All @@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
self.key_mapper.map(key): self.value_mapper.map(value)
for key, value in values.items()
}


Expand All @@ -1151,6 +1192,7 @@ class RowMapperFactory:
lambda functions (one for each column) which will process a data value
and returns a RowMapper instance which will process rows of data
"""

NO_OP_ROW_MAPPER = NoOpRowMapper()

def create(self, columns, legacy_primitive_types):
Expand All @@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types):
def _create_value_mapper(self, column) -> ValueMapper:
col_type = column['rawType']

if col_type == 'array':
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
if col_type == "array":
value_mapper = self._create_value_mapper(column["arguments"][0]["value"])
return ArrayValueMapper(value_mapper)
elif col_type == 'row':
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
elif col_type == "row":
mappers = [
self._create_value_mapper(arg["value"]["typeSignature"])
for arg in column["arguments"]
]
return RowValueMapper(mappers)
elif col_type == 'map':
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
elif col_type == "map":
key_mapper = self._create_value_mapper(column["arguments"][0]["value"])
value_mapper = self._create_value_mapper(column["arguments"][1]["value"])
return MapValueMapper(key_mapper, value_mapper)
elif col_type.startswith('decimal'):
elif col_type.startswith("decimal"):
return DecimalValueMapper()
elif col_type.startswith('double') or col_type.startswith('real'):
elif col_type.startswith("double") or col_type.startswith("real"):
return DoubleValueMapper()
elif col_type.startswith('timestamp') and 'with time zone' in col_type:
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
Expand Down

0 comments on commit 435a30f

Please sign in to comment.