diff --git a/trino/client.py b/trino/client.py index 6f58a9eb..932d425d 100644 --- a/trino/client.py +++ b/trino/client.py @@ -47,18 +47,30 @@ from datetime import date, datetime, time, timedelta, timezone, tzinfo from decimal import Decimal from time import sleep -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import pytz import requests from pytz.tzinfo import BaseTzInfo -from tzlocal import get_localzone_name # type: ignore +from tzlocal import get_localzone_name import trino.logging from trino import constants, exceptions try: - from zoneinfo import ZoneInfo # type: ignore + from zoneinfo import ZoneInfo except ModuleNotFoundError: from backports.zoneinfo import ZoneInfo # type: ignore @@ -75,7 +87,7 @@ else: PROXIES = {} -_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$") T = TypeVar("T") @@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]: "{}={}".format(catalog, urllib.parse.quote(str(role))) for catalog, role in self._client_session.roles.items() ) - if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0: - headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags) + if ( + self._client_session.client_tags is not None + and len(self._client_session.client_tags) > 0 + ): + headers[constants.HEADER_CLIENT_TAGS] = ",".join( + self._client_session.client_tags + ) headers[constants.HEADER_SESSION] = ",".join( # ``name`` must not contain ``=`` @@ -486,8 +503,10 @@ def http_headers(self) -> Dict[str, str]: transaction_id = self._client_session.transaction_id headers[constants.HEADER_TRANSACTION] = transaction_id - if self._client_session.extra_credential is not None and \ - len(self._client_session.extra_credential) > 0: + if ( + self._client_session.extra_credential is not None + and len(self._client_session.extra_credential) > 0 + ): for tup in self._client_session.extra_credential: self._verify_extra_credential(tup) @@ -495,9 +514,12 @@ def http_headers(self) -> Dict[str, str]: # HTTP 1.1 section 4.2 combine multiple extra credentials into a # comma-separated value # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) - headers[constants.HEADER_EXTRA_CREDENTIAL] = \ - ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join( + [ + f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" + for tup in self._client_session.extra_credential + ] + ) return headers @@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non while http_response is not None and http_response.is_redirect: location = http_response.headers["Location"] url = self._redirect_handler.handle(location) - logger.info("redirect %s from %s to %s", http_response.status_code, location, url) + logger.info( + "redirect %s from %s to %s", + http_response.status_code, + location, + url, + ) http_response = self._post( url, data=data, @@ -606,7 +633,7 @@ def raise_response_error(self, http_response): raise exceptions.HttpError( "error {}{}".format( http_response.status_code, - ": {}".format(http_response.content) if http_response.content else "", + ": {}".format(repr(http_response.content)) if http_response.content else "", ) ) @@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus: self._client_session.properties[key] = value if constants.HEADER_SET_CATALOG in http_response.headers: - self._client_session.catalog = http_response.headers[constants.HEADER_SET_CATALOG] + self._client_session.catalog = http_response.headers[ + constants.HEADER_SET_CATALOG + ] if constants.HEADER_SET_SCHEMA in http_response.headers: - self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA] + self._client_session.schema = http_response.headers[ + constants.HEADER_SET_SCHEMA + ] if constants.HEADER_SET_ROLE in http_response.headers: for key, value in get_roles_values( - http_response.headers, constants.HEADER_SET_ROLE + http_response.headers, constants.HEADER_SET_ROLE ): self._client_session.roles[key] = value @@ -676,12 +707,16 @@ def _verify_extra_credential(self, header): key = header[0] if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX.match(key): - raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'") + raise ValueError( + f"whitespace or '=' are disallowed in extra credential '{key}'" + ) try: - key.encode().decode('ascii') + key.encode().decode("ascii") except UnicodeDecodeError: - raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") + raise ValueError( + f"only ASCII characters are allowed in extra credential '{key}'" + ) class TrinoResult(object): @@ -847,7 +882,10 @@ def cancel(self) -> None: def is_finished(self) -> bool: import warnings - warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning) + + warnings.warn( + "is_finished is deprecated, use finished instead", DeprecationWarning + ) return self.finished @property @@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]): def map(self, value) -> Optional[float]: if value is None: return None - if value == 'Infinity': + if value == "Infinity": return float("inf") - if value == '-Infinity': + if value == "-Infinity": return float("-inf") - if value == 'NaN': + if value == "NaN": return float("nan") return float(value) @@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]): def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: if values is None: return None - return tuple(self.mappers[index].map(value) for index, value in enumerate(values)) + return tuple( + self.mappers[index].map(value) for index, value in enumerate(values) + ) class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): @@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: if values is None: return None return { - self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() + self.key_mapper.map(key): self.value_mapper.map(value) + for key, value in values.items() } @@ -1151,6 +1192,7 @@ class RowMapperFactory: lambda functions (one for each column) which will process a data value and returns a RowMapper instance which will process rows of data """ + NO_OP_ROW_MAPPER = NoOpRowMapper() def create(self, columns, legacy_primitive_types): @@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types): def _create_value_mapper(self, column) -> ValueMapper: col_type = column['rawType'] - if col_type == 'array': - value_mapper = self._create_value_mapper(column['arguments'][0]['value']) + if col_type == "array": + value_mapper = self._create_value_mapper(column["arguments"][0]["value"]) return ArrayValueMapper(value_mapper) - elif col_type == 'row': - mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']] + elif col_type == "row": + mappers = [ + self._create_value_mapper(arg["value"]["typeSignature"]) + for arg in column["arguments"] + ] return RowValueMapper(mappers) - elif col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + elif col_type == "map": + key_mapper = self._create_value_mapper(column["arguments"][0]["value"]) + value_mapper = self._create_value_mapper(column["arguments"][1]["value"]) return MapValueMapper(key_mapper, value_mapper) - elif col_type.startswith('decimal'): + elif col_type.startswith("decimal"): return DecimalValueMapper() - elif col_type.startswith('double') or col_type.startswith('real'): + elif col_type.startswith("double") or col_type.startswith("real"): return DoubleValueMapper() elif col_type.startswith('timestamp') and 'with time zone' in col_type: return TimestampWithTimeZoneValueMapper(self._get_precision(column))