From 97b777365abcf7fffad701bdadd19b63c07653da Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Thu, 19 Jan 2023 10:29:10 +0100 Subject: [PATCH 1/5] mypy checks --- setup.cfg | 2 +- trino/client.py | 10 ++-- trino/dbapi.py | 146 ++++++++++++++++++++++++------------------------ 3 files changed, 80 insertions(+), 78 deletions(-) diff --git a/setup.cfg b/setup.cfg index 372d84b0..b473f771 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*] +[mypy-tests.*,trino.client,trino.sqlalchemy.*,trino.dbapi] ignore_errors = true diff --git a/trino/client.py b/trino/client.py index 3ab35b09..ac407363 100644 --- a/trino/client.py +++ b/trino/client.py @@ -125,10 +125,10 @@ class ClientSession(object): def __init__( self, - user: str, - catalog: str = None, - schema: str = None, - source: str = None, + user: Optional[str], + catalog: Optional[str] = None, + schema: Optional[str] = None, + source: Optional[str] = None, properties: Dict[str, str] = None, headers: Dict[str, str] = None, transaction_id: str = None, @@ -401,7 +401,7 @@ def __init__( auth: Optional[Any] = constants.DEFAULT_AUTH, redirect_handler: Any = None, max_attempts: int = MAX_ATTEMPTS, - request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT, + request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, handle_retry=_RetryWithExponentialBackoff(), verify: bool = True, ) -> None: diff --git a/trino/dbapi.py b/trino/dbapi.py index 6cb2a97a..206a6fe0 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -20,9 +20,11 @@ import binascii import datetime import math +import time import uuid from decimal import Decimal -from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types +from types import TracebackType +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union import trino.client import trino.exceptions @@ -72,7 +74,7 @@ logger = trino.logging.get_logger(__name__) -def connect(*args, **kwargs): +def connect(*args: Any, **kwargs: Any) -> trino.dbapi.Connection: """Constructor for creating a connection to the database. See class :py:class:`Connection` for arguments. @@ -92,28 +94,28 @@ class Connection(object): def __init__( self, - host, - port=constants.DEFAULT_PORT, - user=None, - source=constants.DEFAULT_SOURCE, - catalog=constants.DEFAULT_CATALOG, - schema=constants.DEFAULT_SCHEMA, - session_properties=None, - http_headers=None, - http_scheme=constants.HTTP, - auth=constants.DEFAULT_AUTH, - extra_credential=None, - redirect_handler=None, - max_attempts=constants.DEFAULT_MAX_ATTEMPTS, - request_timeout=constants.DEFAULT_REQUEST_TIMEOUT, - isolation_level=IsolationLevel.AUTOCOMMIT, - verify=True, - http_session=None, - client_tags=None, - legacy_primitive_types=False, - roles=None, + host: str, + port: int = constants.DEFAULT_PORT, + user: Optional[str] = None, + source: str = constants.DEFAULT_SOURCE, + catalog: Optional[str] = constants.DEFAULT_CATALOG, + schema: Optional[str] = constants.DEFAULT_SCHEMA, + session_properties: Optional[Dict[str, str]] = None, + http_headers: Optional[Dict[str, str]] = None, + http_scheme: str = constants.HTTP, + auth: Optional[trino.auth.Authentication] = constants.DEFAULT_AUTH, + extra_credential: Optional[List[Tuple[str, str]]] = None, + redirect_handler: Optional[str] = None, + max_attempts: int = constants.DEFAULT_MAX_ATTEMPTS, + request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, + isolation_level: IsolationLevel = IsolationLevel.AUTOCOMMIT, + verify: Union[bool | str] = True, + http_session: Optional[trino.client.TrinoRequest.http.Session] = None, + client_tags: Optional[List[str]] = None, + legacy_primitive_types: Optional[bool] = False, + roles: Optional[Dict[str, str]] = None, timezone=None, - ): + ) -> None: self.host = host self.port = port self.user = user @@ -151,21 +153,24 @@ def __init__( self._isolation_level = isolation_level self._request = None - self._transaction = None + self._transaction: Optional[Transaction] = None self.legacy_primitive_types = legacy_primitive_types @property - def isolation_level(self): + def isolation_level(self) -> IsolationLevel: return self._isolation_level @property - def transaction(self): + def transaction(self) -> Optional[Transaction]: return self._transaction - def __enter__(self): + def __enter__(self) -> object: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType]) -> None: try: self.commit() except Exception: @@ -173,28 +178,28 @@ def __exit__(self, exc_type, exc_value, traceback): else: self.close() - def close(self): + def close(self) -> None: # TODO cancel outstanding queries? self._http_session.close() - def start_transaction(self): + def start_transaction(self) -> Transaction: self._transaction = Transaction(self._create_request()) self._transaction.begin() return self._transaction - def commit(self): + def commit(self) -> None: if self.transaction is None: return - self._transaction.commit() + self.transaction.commit() self._transaction = None - def rollback(self): + def rollback(self) -> None: if self.transaction is None: raise RuntimeError("no transaction was started") - self._transaction.rollback() + self.transaction.rollback() self._transaction = None - def _create_request(self): + def _create_request(self) -> trino.client.TrinoRequest: return trino.client.TrinoRequest( self.host, self.port, @@ -207,7 +212,7 @@ def _create_request(self): self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None): + def cursor(self, legacy_primitive_types: bool = None) -> 'trino.dbapi.Cursor': """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -271,7 +276,10 @@ class Cursor(object): """ - def __init__(self, connection, request, legacy_primitive_types: bool = False): + def __init__(self, + connection: Connection, + request: trino.client.TrinoRequest, + legacy_primitive_types: bool = False) -> None: if not isinstance(connection, Connection): raise ValueError( "connection must be a Connection object: {}".format(type(connection)) @@ -280,32 +288,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False): self._request = request self.arraysize = 1 - self._iterator = None - self._query = None + self._iterator: Optional[Iterator[List[Any]]] = None + self._query: Optional[trino.client.TrinoQuery] = None self._legacy_primitive_types = legacy_primitive_types - def __iter__(self): + def __iter__(self) -> Optional[Iterator[List[Any]]]: return self._iterator @property - def connection(self): + def connection(self) -> Connection: return self._connection @property - def info_uri(self): + def info_uri(self) -> Optional[str]: if self._query is not None: return self._query.info_uri return None @property - def update_type(self): + def update_type(self) -> Optional[str]: if self._query is not None: return self._query.update_type return None @property - def description(self) -> List[ColumnDescription]: - if self._query.columns is None: + def description(self) -> Optional[List[Tuple[Any, ...]]]: + if self._query is None or self._query.columns is None: return None # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ] @@ -314,7 +322,7 @@ def description(self) -> List[ColumnDescription]: ] @property - def rowcount(self): + def rowcount(self) -> int: """Not supported. Trino cannot reliablity determine the number of rows returned by an @@ -325,27 +333,21 @@ def rowcount(self): return -1 @property - def stats(self): + def stats(self) -> Optional[Dict[Any, Any]]: if self._query is not None: return self._query.stats return None @property - def query_id(self) -> Optional[str]: - if self._query is not None: - return self._query.query_id - return None - - @property - def warnings(self): + def warnings(self) -> Optional[List[Dict[Any, Any]]]: if self._query is not None: return self._query.warnings return None - def setinputsizes(self, sizes): + def setinputsizes(self, sizes: Sequence[Any]) -> None: raise trino.exceptions.NotSupportedError - def setoutputsize(self, size, column): + def setoutputsize(self, size: int, column: Optional[int]) -> None: raise trino.exceptions.NotSupportedError def _prepare_statement(self, statement: str, name: str) -> None: @@ -363,13 +365,13 @@ def _prepare_statement(self, statement: str, name: str) -> None: def _execute_prepared_statement( self, - statement_name, - params - ): + statement_name: str, + params: Any + ) -> trino.client.TrinoQuery: sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) return trino.client.TrinoQuery(self._request, sql=sql, legacy_primitive_types=self._legacy_primitive_types) - def _format_prepared_param(self, param): + def _format_prepared_param(self, param: Any) -> str: """ Formats parameters to be passed in an EXECUTE statement. @@ -451,10 +453,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None: legacy_primitive_types=self._legacy_primitive_types) query.execute() - def _generate_unique_statement_name(self): + def _generate_unique_statement_name(self) -> str: return 'st_' + uuid.uuid4().hex.replace('-', '') - def execute(self, operation, params=None): + def execute(self, operation: str, params: Optional[Any] = None) -> trino.client.TrinoResult: if params: assert isinstance(params, (list, tuple)), ( 'params must be a list or tuple containing the query ' @@ -484,7 +486,7 @@ def execute(self, operation, params=None): self._iterator = iter(self._query.execute()) return self - def executemany(self, operation, seq_of_params): + def executemany(self, operation: str, seq_of_params: Any) -> None: """ PEP-0249: Prepare a database operation (query or command) and then execute it against all parameter sequences or mappings found in the sequence seq_of_parameters. @@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]: except trino.exceptions.HttpError as err: raise trino.exceptions.OperationalError(str(err)) - def fetchmany(self, size=None) -> List[List[Any]]: + def fetchmany(self, size: Optional[int] = None) -> List[List[Any]]: """ PEP-0249: Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is @@ -584,20 +586,20 @@ def describe(self, sql: str) -> List[DescribeOutput]: return list(map(lambda x: DescribeOutput.from_row(x), result)) - def genall(self): + def genall(self) -> trino.client.TrinoResult: return self._query.result def fetchall(self) -> List[List[Any]]: return list(self.genall()) - def cancel(self): + def cancel(self) -> None: if self._query is None: raise trino.exceptions.OperationalError( "Cancel query failed; no running query" ) self._query.cancel() - def close(self): + def close(self) -> None: self.cancel() # TODO: Cancel not only the last query executed on this cursor # but also any other outstanding queries executed through this cursor. @@ -610,19 +612,19 @@ def close(self): TimestampFromTicks = datetime.datetime.fromtimestamp -def TimeFromTicks(ticks): - return datetime.time(*datetime.localtime(ticks)[3:6]) +def TimeFromTicks(ticks: int) -> datetime.time: + return datetime.time(*time.localtime(ticks)[3:6]) -def Binary(string): +def Binary(string: str) -> bytes: return string.encode("utf-8") class DBAPITypeObject: - def __init__(self, *values): + def __init__(self, *values: str): self.values = [v.lower() for v in values] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return other.lower() in self.values From a45c0000af1aa0ccfa92fe79c5d3241943e00d75 Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Thu, 19 Jan 2023 10:30:42 +0100 Subject: [PATCH 2/5] Small fixes --- trino/client.py | 3 ++- trino/dbapi.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/trino/client.py b/trino/client.py index ac407363..6f58a9eb 100644 --- a/trino/client.py +++ b/trino/client.py @@ -753,7 +753,8 @@ def columns(self): while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. # Continue fetching data until columns information is available and push fetched rows into buffer. - self._result.rows += self.fetch() + if self._result: + self._result.rows += self.fetch() return self._columns @property diff --git a/trino/dbapi.py b/trino/dbapi.py index 206a6fe0..d538fedf 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -338,6 +338,12 @@ def stats(self) -> Optional[Dict[Any, Any]]: return self._query.stats return None + @property + def query_id(self) -> Optional[str]: + if self._query is not None: + return self._query.query_id + return None + @property def warnings(self) -> Optional[List[Dict[Any, Any]]]: if self._query is not None: @@ -505,6 +511,7 @@ def executemany(self, operation: str, seq_of_params: Any) -> None: for parameters in seq_of_params[:-1]: self.execute(operation, parameters) self.fetchall() + assert self._query is not None if self._query.update_type is None: raise NotSupportedError("Query must return update type") if seq_of_params: @@ -586,8 +593,10 @@ def describe(self, sql: str) -> List[DescribeOutput]: return list(map(lambda x: DescribeOutput.from_row(x), result)) - def genall(self) -> trino.client.TrinoResult: - return self._query.result + def genall(self) -> Any: + if self._query: + return self._query.result + return None def fetchall(self) -> List[List[Any]]: return list(self.genall()) @@ -625,6 +634,8 @@ def __init__(self, *values: str): self.values = [v.lower() for v in values] def __eq__(self, other: object) -> bool: + if not isinstance(other, str): + return NotImplemented return other.lower() in self.values From 435a30fc8cd10c43ee852c1d08912f7f5bd60bde Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Fri, 20 Jan 2023 11:39:51 +0100 Subject: [PATCH 3/5] Reformatting --- trino/client.py | 113 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 34 deletions(-) diff --git a/trino/client.py b/trino/client.py index 6f58a9eb..932d425d 100644 --- a/trino/client.py +++ b/trino/client.py @@ -47,18 +47,30 @@ from datetime import date, datetime, time, timedelta, timezone, tzinfo from decimal import Decimal from time import sleep -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import pytz import requests from pytz.tzinfo import BaseTzInfo -from tzlocal import get_localzone_name # type: ignore +from tzlocal import get_localzone_name import trino.logging from trino import constants, exceptions try: - from zoneinfo import ZoneInfo # type: ignore + from zoneinfo import ZoneInfo except ModuleNotFoundError: from backports.zoneinfo import ZoneInfo # type: ignore @@ -75,7 +87,7 @@ else: PROXIES = {} -_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$") T = TypeVar("T") @@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]: "{}={}".format(catalog, urllib.parse.quote(str(role))) for catalog, role in self._client_session.roles.items() ) - if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0: - headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags) + if ( + self._client_session.client_tags is not None + and len(self._client_session.client_tags) > 0 + ): + headers[constants.HEADER_CLIENT_TAGS] = ",".join( + self._client_session.client_tags + ) headers[constants.HEADER_SESSION] = ",".join( # ``name`` must not contain ``=`` @@ -486,8 +503,10 @@ def http_headers(self) -> Dict[str, str]: transaction_id = self._client_session.transaction_id headers[constants.HEADER_TRANSACTION] = transaction_id - if self._client_session.extra_credential is not None and \ - len(self._client_session.extra_credential) > 0: + if ( + self._client_session.extra_credential is not None + and len(self._client_session.extra_credential) > 0 + ): for tup in self._client_session.extra_credential: self._verify_extra_credential(tup) @@ -495,9 +514,12 @@ def http_headers(self) -> Dict[str, str]: # HTTP 1.1 section 4.2 combine multiple extra credentials into a # comma-separated value # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) - headers[constants.HEADER_EXTRA_CREDENTIAL] = \ - ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join( + [ + f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" + for tup in self._client_session.extra_credential + ] + ) return headers @@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non while http_response is not None and http_response.is_redirect: location = http_response.headers["Location"] url = self._redirect_handler.handle(location) - logger.info("redirect %s from %s to %s", http_response.status_code, location, url) + logger.info( + "redirect %s from %s to %s", + http_response.status_code, + location, + url, + ) http_response = self._post( url, data=data, @@ -606,7 +633,7 @@ def raise_response_error(self, http_response): raise exceptions.HttpError( "error {}{}".format( http_response.status_code, - ": {}".format(http_response.content) if http_response.content else "", + ": {}".format(repr(http_response.content)) if http_response.content else "", ) ) @@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus: self._client_session.properties[key] = value if constants.HEADER_SET_CATALOG in http_response.headers: - self._client_session.catalog = http_response.headers[constants.HEADER_SET_CATALOG] + self._client_session.catalog = http_response.headers[ + constants.HEADER_SET_CATALOG + ] if constants.HEADER_SET_SCHEMA in http_response.headers: - self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA] + self._client_session.schema = http_response.headers[ + constants.HEADER_SET_SCHEMA + ] if constants.HEADER_SET_ROLE in http_response.headers: for key, value in get_roles_values( - http_response.headers, constants.HEADER_SET_ROLE + http_response.headers, constants.HEADER_SET_ROLE ): self._client_session.roles[key] = value @@ -676,12 +707,16 @@ def _verify_extra_credential(self, header): key = header[0] if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX.match(key): - raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'") + raise ValueError( + f"whitespace or '=' are disallowed in extra credential '{key}'" + ) try: - key.encode().decode('ascii') + key.encode().decode("ascii") except UnicodeDecodeError: - raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") + raise ValueError( + f"only ASCII characters are allowed in extra credential '{key}'" + ) class TrinoResult(object): @@ -847,7 +882,10 @@ def cancel(self) -> None: def is_finished(self) -> bool: import warnings - warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning) + + warnings.warn( + "is_finished is deprecated, use finished instead", DeprecationWarning + ) return self.finished @property @@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]): def map(self, value) -> Optional[float]: if value is None: return None - if value == 'Infinity': + if value == "Infinity": return float("inf") - if value == '-Infinity': + if value == "-Infinity": return float("-inf") - if value == 'NaN': + if value == "NaN": return float("nan") return float(value) @@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]): def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: if values is None: return None - return tuple(self.mappers[index].map(value) for index, value in enumerate(values)) + return tuple( + self.mappers[index].map(value) for index, value in enumerate(values) + ) class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): @@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: if values is None: return None return { - self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() + self.key_mapper.map(key): self.value_mapper.map(value) + for key, value in values.items() } @@ -1151,6 +1192,7 @@ class RowMapperFactory: lambda functions (one for each column) which will process a data value and returns a RowMapper instance which will process rows of data """ + NO_OP_ROW_MAPPER = NoOpRowMapper() def create(self, columns, legacy_primitive_types): @@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types): def _create_value_mapper(self, column) -> ValueMapper: col_type = column['rawType'] - if col_type == 'array': - value_mapper = self._create_value_mapper(column['arguments'][0]['value']) + if col_type == "array": + value_mapper = self._create_value_mapper(column["arguments"][0]["value"]) return ArrayValueMapper(value_mapper) - elif col_type == 'row': - mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']] + elif col_type == "row": + mappers = [ + self._create_value_mapper(arg["value"]["typeSignature"]) + for arg in column["arguments"] + ] return RowValueMapper(mappers) - elif col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + elif col_type == "map": + key_mapper = self._create_value_mapper(column["arguments"][0]["value"]) + value_mapper = self._create_value_mapper(column["arguments"][1]["value"]) return MapValueMapper(key_mapper, value_mapper) - elif col_type.startswith('decimal'): + elif col_type.startswith("decimal"): return DecimalValueMapper() - elif col_type.startswith('double') or col_type.startswith('real'): + elif col_type.startswith("double") or col_type.startswith("real"): return DoubleValueMapper() elif col_type.startswith('timestamp') and 'with time zone' in col_type: return TimestampWithTimeZoneValueMapper(self._get_precision(column)) From 02be1916d76e7b2e916a9bb1019dbda464d29b4e Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Fri, 20 Jan 2023 11:43:01 +0100 Subject: [PATCH 4/5] Enable mypy checks for client.py --- setup.cfg | 2 +- trino/client.py | 267 +++++++++++++++++++++++++++----------------- trino/dbapi.py | 43 ++++--- trino/exceptions.py | 2 +- 4 files changed, 193 insertions(+), 121 deletions(-) diff --git a/setup.cfg b/setup.cfg index b473f771..04a40888 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.sqlalchemy.*,trino.dbapi] +[mypy-tests.*,trino.sqlalchemy.*] ignore_errors = true diff --git a/trino/client.py b/trino/client.py index 932d425d..9c02d6d0 100644 --- a/trino/client.py +++ b/trino/client.py @@ -44,6 +44,7 @@ import threading import urllib.parse import warnings +from abc import ABC, abstractmethod from datetime import date, datetime, time, timedelta, timezone, tzinfo from decimal import Decimal from time import sleep @@ -59,6 +60,7 @@ Type, TypeVar, Union, + overload, ) import pytz @@ -141,14 +143,14 @@ def __init__( catalog: Optional[str] = None, schema: Optional[str] = None, source: Optional[str] = None, - properties: Dict[str, str] = None, - headers: Dict[str, str] = None, - transaction_id: str = None, - extra_credential: List[Tuple[str, str]] = None, - client_tags: List[str] = None, - roles: Dict[str, str] = None, - timezone: str = None, - ): + properties: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + transaction_id: Optional[str] = None, + extra_credential: Optional[List[Tuple[str, str]]] = None, + client_tags: Optional[List[str]] = None, + roles: Optional[Dict[str, str]] = None, + timezone: Optional[str] = None, + ) -> None: self._user = user self._catalog = catalog self._schema = schema @@ -166,90 +168,90 @@ def __init__( ZoneInfo(timezone) @property - def user(self): + def user(self) -> Optional[str]: return self._user @property - def catalog(self): + def catalog(self) -> Optional[str]: with self._object_lock: return self._catalog @catalog.setter - def catalog(self, catalog): + def catalog(self, catalog: Optional[str]) -> None: with self._object_lock: self._catalog = catalog @property - def schema(self): + def schema(self) -> Optional[str]: with self._object_lock: return self._schema @schema.setter - def schema(self, schema): + def schema(self, schema: Optional[str]) -> None: with self._object_lock: self._schema = schema @property - def source(self): + def source(self) -> Optional[str]: return self._source @property - def properties(self): + def properties(self) -> Dict[str, str]: with self._object_lock: return self._properties @properties.setter - def properties(self, properties): + def properties(self, properties: Dict[str, str]) -> None: with self._object_lock: self._properties = properties @property - def headers(self): + def headers(self) -> Dict[str, str]: return self._headers @property - def transaction_id(self): + def transaction_id(self) -> Optional[str]: with self._object_lock: return self._transaction_id @transaction_id.setter - def transaction_id(self, transaction_id): + def transaction_id(self, transaction_id: Optional[str]) -> None: with self._object_lock: self._transaction_id = transaction_id @property - def extra_credential(self): + def extra_credential(self) -> Optional[List[Tuple[str, str]]]: return self._extra_credential @property - def client_tags(self): + def client_tags(self) -> List[str]: return self._client_tags @property - def roles(self): + def roles(self) -> Dict[str, str]: with self._object_lock: return self._roles @roles.setter - def roles(self, roles): + def roles(self, roles: Dict[str, str]) -> None: with self._object_lock: self._roles = roles @property - def prepared_statements(self): + def prepared_statements(self) -> Dict[str, str]: return self._prepared_statements @prepared_statements.setter - def prepared_statements(self, prepared_statements): + def prepared_statements(self, prepared_statements: Dict[str, str]) -> None: with self._object_lock: self._prepared_statements = prepared_statements @property - def timezone(self): + def timezone(self) -> Optional[str]: with self._object_lock: return self._timezone - def _format_roles(self, roles): + def _format_roles(self, roles: Dict[str, str]) -> Dict[str, str]: formatted_roles = {} for catalog, role in roles.items(): is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None @@ -264,21 +266,25 @@ def _format_roles(self, roles): formatted_roles[catalog] = f"ROLE{{{role}}}" return formatted_roles - def __getstate__(self): + def __getstate__(self) -> Dict[str, str]: state = self.__dict__.copy() del state["_object_lock"] return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, str]) -> None: self.__dict__.update(state) self._object_lock = threading.Lock() -def get_header_values(headers, header): +def get_header_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[str]: return [val.strip() for val in headers[header].split(",")] -def get_session_property_values(headers, header): +def get_session_property_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -286,7 +292,9 @@ def get_session_property_values(headers, header): ] -def get_prepared_statement_values(headers, header): +def get_prepared_statement_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -294,7 +302,9 @@ def get_prepared_statement_values(headers, header): ] -def get_roles_values(headers, header): +def get_roles_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -303,7 +313,17 @@ def get_roles_values(headers, header): class TrinoStatus(object): - def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): + def __init__( + self, + id: str, + stats: Dict[str, Any], + warnings: List[Any], + info_uri: str, + next_uri: Optional[str], + update_type: Any, + rows: List[Any], + columns: Optional[List[str]] = None, + ) -> None: self.id = id self.stats = stats self.warnings = warnings @@ -313,7 +333,7 @@ def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, c self.rows = rows self.columns = columns - def __repr__(self): + def __repr__(self) -> str: return ( "TrinoStatus(" "id={}, stats={{...}}, warnings={}, info_uri={}, next_uri={}, rows=" @@ -329,15 +349,19 @@ def __repr__(self): class _DelayExponential(object): def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + self, + base: float = 0.1, # 100ms + exponent: int = 2, + jitter: bool = True, + max_delay: int = 2 * 3600, # 2 hours + ) -> None: self._base = base self._exponent = exponent self._jitter = jitter self._max_delay = max_delay - def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) + def __call__(self, attempt: int) -> float: + delay = float(self._base) * (self._exponent**attempt) if self._jitter: delay *= random.random() delay = min(float(self._max_delay), delay) @@ -346,11 +370,15 @@ def __call__(self, attempt): class _RetryWithExponentialBackoff(object): def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + self, + base: float = 0.1, # 100ms + exponent: int = 2, + jitter: bool = True, + max_delay: int = 2 * 3600, # 2 hours + ) -> None: self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) - def retry(self, func, args, kwargs, err, attempt): + def retry(self, attempt: int) -> None: delay = self._get_delay(attempt) sleep(delay) @@ -409,12 +437,12 @@ def __init__( port: int, client_session: ClientSession, http_session: Any = None, - http_scheme: str = None, + http_scheme: Optional[str] = None, auth: Optional[Any] = constants.DEFAULT_AUTH, redirect_handler: Any = None, max_attempts: int = MAX_ATTEMPTS, request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, - handle_retry=_RetryWithExponentialBackoff(), + handle_retry: _RetryWithExponentialBackoff = _RetryWithExponentialBackoff(), verify: bool = True, ) -> None: self._client_session = client_session @@ -450,15 +478,15 @@ def __init__( self.max_attempts = max_attempts @property - def transaction_id(self): + def transaction_id(self) -> Optional[str]: return self._client_session.transaction_id @transaction_id.setter - def transaction_id(self, value): + def transaction_id(self, value: Optional[str]) -> None: self._client_session.transaction_id = value @property - def http_headers(self) -> Dict[str, str]: + def http_headers(self) -> Dict[str, Optional[str]]: headers = {} headers[constants.HEADER_CATALOG] = self._client_session.catalog @@ -528,7 +556,7 @@ def max_attempts(self) -> int: return self._max_attempts @max_attempts.setter - def max_attempts(self, value) -> None: + def max_attempts(self, value: int) -> None: self._max_attempts = value if value == 1: # No retry self._get = self._http_session.get @@ -550,7 +578,7 @@ def max_attempts(self, value) -> None: self._post = with_retry(self._http_session.post) self._delete = with_retry(self._http_session.delete) - def get_url(self, path) -> str: + def get_url(self, path: str) -> str: return "{protocol}://{host}:{port}{path}".format( protocol=self._http_scheme, host=self._host, port=self._port, path=path ) @@ -563,7 +591,9 @@ def statement_url(self) -> str: def next_uri(self) -> Optional[str]: return self._next_uri - def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None): + def post( + self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None + ) -> requests.Response: data = sql.encode("utf-8") # Deep copy of the http_headers dict since they may be modified for this # request by the provided additional_http_headers @@ -600,7 +630,7 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non ) return http_response - def get(self, url: str): + def get(self, url: Optional[str]) -> requests.Response: return self._get( url, headers=self.http_headers, @@ -608,10 +638,12 @@ def get(self, url: str): proxies=PROXIES, ) - def delete(self, url): + def delete(self, url: str) -> requests.Response: return self._delete(url, timeout=self._request_timeout, proxies=PROXIES) - def _process_error(self, error, query_id): + def _process_error( + self, error: Dict[str, Any], query_id: str + ) -> Union[exceptions.TrinoUserError, exceptions.TrinoQueryError]: error_type = error["errorType"] if error_type == "EXTERNAL": raise exceptions.TrinoExternalError(error, query_id) @@ -620,7 +652,7 @@ def _process_error(self, error, query_id): return exceptions.TrinoQueryError(error, query_id) - def raise_response_error(self, http_response): + def raise_response_error(self, http_response: requests.Response) -> None: if http_response.status_code == 502: raise exceptions.Http502Error("error 502: bad gateway") @@ -637,7 +669,7 @@ def raise_response_error(self, http_response): ) ) - def process(self, http_response) -> TrinoStatus: + def process(self, http_response: requests.Response) -> TrinoStatus: if not http_response.ok: self.raise_response_error(http_response) @@ -700,7 +732,7 @@ def process(self, http_response) -> TrinoStatus: columns=response.get("columns"), ) - def _verify_extra_credential(self, header): + def _verify_extra_credential(self, header: Tuple[str, str]) -> None: """ Verifies that key has ASCII only and non-whitespace characters. """ @@ -727,25 +759,25 @@ class TrinoResult(object): https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows: List[Any]): + def __init__(self, query: Any, rows: List[Any]) -> None: self._query = query # Initial rows from the first POST request self._rows = rows self._rownumber = 0 @property - def rows(self): + def rows(self) -> List[Any]: return self._rows @rows.setter - def rows(self, rows): + def rows(self, rows: List[Any]) -> None: self._rows = rows @property def rownumber(self) -> int: return self._rownumber - def __iter__(self): + def __iter__(self) -> Generator[Any, None, None]: # A query only transitions to a FINISHED state when the results are fully consumed: # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. while not self._query.finished or self._rows is not None: @@ -780,10 +812,10 @@ def __init__( self._sql = sql self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types - self._row_mapper: Optional[RowMapper] = None + self._row_mapper: Optional[RowMapperBase] = None @property - def columns(self): + def columns(self) -> Any: if self.query_id: while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. @@ -793,26 +825,28 @@ def columns(self): return self._columns @property - def stats(self): + def stats(self) -> Dict[Any, Any]: return self._stats @property - def update_type(self): + def update_type(self) -> Any: return self._update_type @property - def warnings(self): + def warnings(self) -> List[Dict[Any, Any]]: return self._warnings @property - def result(self): + def result(self) -> Optional[TrinoResult]: return self._result @property - def info_uri(self): + def info_uri(self) -> Optional[str]: return self._info_uri - def execute(self, additional_http_headers=None) -> TrinoResult: + def execute( + self, additional_http_headers: Optional[Dict[str, Any]] = None + ) -> TrinoResult: """Initiate a Trino query by sending the SQL statement This is the first HTTP request sent to the coordinator. @@ -841,7 +875,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: self._result.rows += self.fetch() return self._result - def _update_state(self, status): + def _update_state(self, status: TrinoStatus) -> None: self._stats.update(status.stats) self._update_type = status.update_type if not self._row_mapper and status.columns: @@ -897,23 +931,30 @@ def cancelled(self) -> bool: return self._cancelled -def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): - def wrapper(func): +def _retry_with( + handle_retry: _RetryWithExponentialBackoff, + handled_exceptions: Tuple[ + Type[requests.exceptions.ConnectionError], Type[requests.exceptions.Timeout] + ], + conditions: Tuple[Callable[[Any], bool]], + max_attempts: int, +) -> Callable[[Any], Any]: + def wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @functools.wraps(func) - def decorated(*args, **kwargs): + def decorated(*args: Any, **kwargs: Any) -> Optional[Any]: error = None result = None for attempt in range(1, max_attempts + 1): try: result = func(*args, **kwargs) if any(guard(result) for guard in conditions): - handle_retry.retry(func, args, kwargs, None, attempt) + handle_retry.retry(attempt) continue return result except Exception as err: error = err if any(isinstance(err, exc) for exc in handled_exceptions): - handle_retry.retry(func, args, kwargs, err, attempt) + handle_retry.retry(attempt) continue break logger.info("failed after %s attempts", attempt) @@ -933,19 +974,19 @@ def map(self, value: Any) -> Optional[T]: class NoOpValueMapper(ValueMapper[Any]): - def map(self, value) -> Optional[Any]: + def map(self, value: Optional[Any]) -> Optional[Any]: return value class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Optional[Decimal]: + def map(self, value: Optional[Any]) -> Optional[Decimal]: if value is None: return None return Decimal(value) class DoubleValueMapper(ValueMapper[float]): - def map(self, value) -> Optional[float]: + def map(self, value: Optional[str]) -> Optional[float]: if value is None: return None if value == "Infinity": @@ -973,7 +1014,7 @@ def _fraction_to_decimal(fractional_str: str) -> Decimal: class TemporalType(Generic[PythonTemporalType], metaclass=abc.ABCMeta): - def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal): + def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal) -> None: self._whole_python_temporal_value = whole_python_temporal_value self._remaining_fractional_seconds = remaining_fractional_seconds @@ -985,7 +1026,7 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal def to_python_type(self) -> PythonTemporalType: pass - def round_to(self, precision: int) -> TemporalType: + def round_to(self, precision: int) -> TemporalType[Any]: """ Python datetime and time only support up to microsecond precision In case the supplied value exceeds the specified precision, @@ -1066,11 +1107,11 @@ def normalize(self, value: datetime) -> datetime: class TimeValueMapper(ValueMapper[time]): - def __init__(self, precision): + def __init__(self, precision: int) -> None: self.time_default_size = 8 # size of 'HH:MM:SS' self.precision = precision - def map(self, value) -> Optional[time]: + def map(self, value: Optional[str]) -> Optional[time]: if value is None: return None whole_python_temporal_value = value[:self.time_default_size] @@ -1085,7 +1126,7 @@ def _add_second(self, time_value: time) -> time: class TimeWithTimeZoneValueMapper(TimeValueMapper): - def map(self, value) -> Optional[time]: + def map(self, value: Optional[str]) -> Optional[time]: if value is None: return None whole_python_temporal_value = value[:self.time_default_size] @@ -1098,18 +1139,18 @@ def map(self, value) -> Optional[time]: class DateValueMapper(ValueMapper[date]): - def map(self, value) -> Optional[date]: + def map(self, value: Optional[str]) -> Optional[date]: if value is None: return None return date.fromisoformat(value) class TimestampValueMapper(ValueMapper[datetime]): - def __init__(self, precision): + def __init__(self, precision: int) -> None: self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) self.precision = precision - def map(self, value) -> Optional[datetime]: + def map(self, value: Optional[str]) -> Optional[datetime]: if value is None: return None whole_python_temporal_value = value[:self.datetime_default_size] @@ -1121,7 +1162,7 @@ def map(self, value) -> Optional[datetime]: class TimestampWithTimeZoneValueMapper(TimestampValueMapper): - def map(self, value) -> Optional[datetime]: + def map(self, value: Optional[str]) -> Optional[datetime]: if value is None: return None datetime_with_fraction, timezone_part = value.rsplit(' ', 1) @@ -1134,27 +1175,27 @@ def map(self, value) -> Optional[datetime]: class BinaryValueMapper(ValueMapper[bytes]): - def map(self, value) -> Optional[bytes]: + def map(self, value: Optional[str]) -> Optional[bytes]: if value is None: return None return base64.b64decode(value.encode("utf8")) class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): - def __init__(self, mapper: ValueMapper[Any]): + def __init__(self, mapper: ValueMapper[Any]) -> None: self.mapper = mapper - def map(self, values: List[Any]) -> Optional[List[Any]]: + def map(self, values: Optional[List[Any]]) -> Optional[List[Any]]: if values is None: return None return [self.mapper.map(value) for value in values] class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): - def __init__(self, mappers: List[ValueMapper[Any]]): + def __init__(self, mappers: List[ValueMapper[Any]]) -> None: self.mappers = mappers - def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: + def map(self, values: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]: if values is None: return None return tuple( @@ -1163,7 +1204,7 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): - def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): + def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]) -> None: self.key_mapper = key_mapper self.value_mapper = value_mapper @@ -1176,13 +1217,27 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: } -class NoOpRowMapper: +class RowMapperBase(ABC): + @overload + def map(self, rows: List[Any]) -> List[Any]: + ... + + @overload + def map(self, rows: Any) -> Any: + ... + + @abstractmethod + def map(self, rows: Any) -> Any: + pass + + +class NoOpRowMapper(RowMapperBase): """ No-op RowMapper which does not perform any transformation Used when legacy_primitive_types is False. """ - def map(self, rows): + def map(self, rows: Any) -> Any: return rows @@ -1195,15 +1250,15 @@ class RowMapperFactory: NO_OP_ROW_MAPPER = NoOpRowMapper() - def create(self, columns, legacy_primitive_types): + def create(self, columns: Any, legacy_primitive_types: bool) -> Optional[RowMapperBase]: assert columns is not None if not legacy_primitive_types: return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns]) return RowMapperFactory.NO_OP_ROW_MAPPER - def _create_value_mapper(self, column) -> ValueMapper: - col_type = column['rawType'] + def _create_value_mapper(self, column: Any) -> ValueMapper[Any]: + col_type = column["rawType"] if col_type == "array": value_mapper = self._create_value_mapper(column["arguments"][0]["value"]) @@ -1237,29 +1292,33 @@ def _create_value_mapper(self, column) -> ValueMapper: else: return NoOpValueMapper() - def _get_precision(self, column: Dict[str, Any]): + def _get_precision(self, column: Dict[str, Any]) -> int: args = column['arguments'] if len(args) == 0: return 3 return args[0]['value'] -class RowMapper: +class RowMapper(RowMapperBase): """ Maps a row of data given a list of mapping functions """ - def __init__(self, columns): + + def __init__(self, columns: List[Any]) -> None: self.columns = columns - def map(self, rows): + def map(self, rows: List[Any]) -> List[Any]: if len(self.columns) == 0: return rows return [self._map_row(row) for row in rows] - def _map_row(self, row): - return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)] + def _map_row(self, row: str) -> List[Optional[T]]: + return [ + self._map_value(value, self.columns[index]) + for index, value in enumerate(row) + ] - def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]: + def _map_value(self, value: Any, value_mapper: ValueMapper[T]) -> Optional[T]: try: return value_mapper.map(value) except ValueError as e: diff --git a/trino/dbapi.py b/trino/dbapi.py index d538fedf..88505631 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -17,6 +17,8 @@ Fetch methods returns rows as a list of lists on purpose to let the caller decide to convert then to a list of tuples. """ +from __future__ import annotations + import binascii import datetime import math @@ -24,7 +26,18 @@ import uuid from decimal import Decimal from types import TracebackType -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + Union, +) import trino.client import trino.exceptions @@ -74,7 +87,7 @@ logger = trino.logging.get_logger(__name__) -def connect(*args: Any, **kwargs: Any) -> trino.dbapi.Connection: +def connect(*args: Any, **kwargs: Any) -> Connection: """Constructor for creating a connection to the database. See class :py:class:`Connection` for arguments. @@ -109,12 +122,12 @@ def __init__( max_attempts: int = constants.DEFAULT_MAX_ATTEMPTS, request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, isolation_level: IsolationLevel = IsolationLevel.AUTOCOMMIT, - verify: Union[bool | str] = True, + verify: Union[bool, str] = True, http_session: Optional[trino.client.TrinoRequest.http.Session] = None, client_tags: Optional[List[str]] = None, - legacy_primitive_types: Optional[bool] = False, + legacy_primitive_types: bool = False, roles: Optional[Dict[str, str]] = None, - timezone=None, + timezone: Optional[str] = None, ) -> None: self.host = host self.port = port @@ -164,7 +177,7 @@ def isolation_level(self) -> IsolationLevel: def transaction(self) -> Optional[Transaction]: return self._transaction - def __enter__(self) -> object: + def __enter__(self) -> Connection: return self def __exit__(self, @@ -212,7 +225,7 @@ def _create_request(self) -> trino.client.TrinoRequest: self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None) -> 'trino.dbapi.Cursor': + def cursor(self, legacy_primitive_types: Optional[bool] = None) -> Cursor: """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -239,21 +252,21 @@ class DescribeOutput(NamedTuple): aliased: bool @classmethod - def from_row(cls, row: List[Any]): + def from_row(cls, row: List[Any]) -> DescribeOutput: return cls(*row) class ColumnDescription(NamedTuple): name: str type_code: int - display_size: int + display_size: Optional[int] internal_size: int precision: int scale: int - null_ok: bool + null_ok: Optional[bool] @classmethod - def from_column(cls, column: Dict[str, Any]): + def from_column(cls, column: Dict[str, Any]) -> ColumnDescription: type_signature = column["typeSignature"] raw_type = type_signature["rawType"] arguments = type_signature["arguments"] @@ -312,7 +325,7 @@ def update_type(self) -> Optional[str]: return None @property - def description(self) -> Optional[List[Tuple[Any, ...]]]: + def description(self) -> Optional[List[ColumnDescription]]: if self._query is None or self._query.columns is None: return None @@ -462,7 +475,7 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None: def _generate_unique_statement_name(self) -> str: return 'st_' + uuid.uuid4().hex.replace('-', '') - def execute(self, operation: str, params: Optional[Any] = None) -> trino.client.TrinoResult: + def execute(self, operation: str, params: Optional[Any] = None) -> Cursor: if params: assert isinstance(params, (list, tuple)), ( 'params must be a list or tuple containing the query ' @@ -492,7 +505,7 @@ def execute(self, operation: str, params: Optional[Any] = None) -> trino.client. self._iterator = iter(self._query.execute()) return self - def executemany(self, operation: str, seq_of_params: Any) -> None: + def executemany(self, operation: str, seq_of_params: Any) -> Cursor: """ PEP-0249: Prepare a database operation (query or command) and then execute it against all parameter sequences or mappings found in the sequence seq_of_parameters. @@ -598,7 +611,7 @@ def genall(self) -> Any: return self._query.result return None - def fetchall(self) -> List[List[Any]]: + def fetchall(self) -> Optional[List[List[Any]]]: return list(self.genall()) def cancel(self) -> None: diff --git a/trino/exceptions.py b/trino/exceptions.py index d48fc9ef..ccbd19b8 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -72,7 +72,7 @@ class TrinoDataError(NotSupportedError): class TrinoQueryError(Error): - def __init__(self, error: Dict[str, Any], query_id: Optional[str] = None) -> None: + def __init__(self, error: Any, query_id: Optional[str] = None) -> None: self._error = error self._query_id = query_id From 3de613f7314803e42544fd886a73b9212e6a3097 Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Wed, 18 Jan 2023 19:34:56 +0100 Subject: [PATCH 5/5] Small fix in trino.exceptions.TrinoQueryError --- trino/exceptions.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/trino/exceptions.py b/trino/exceptions.py index ccbd19b8..ca083d08 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -14,7 +14,7 @@ This module defines exceptions for Trino operations. It follows the structure defined in pep-0249. """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import trino.logging @@ -72,38 +72,44 @@ class TrinoDataError(NotSupportedError): class TrinoQueryError(Error): - def __init__(self, error: Any, query_id: Optional[str] = None) -> None: - self._error = error + def __init__(self, error: Union[Dict[str, Any], str], query_id: Optional[str] = None) -> None: + if isinstance(error, dict): + self._error = error + elif isinstance(error, str): + self._error = {"message": error} self._query_id = query_id @property def error_code(self) -> Optional[int]: - return self._error.get("errorCode", None) + return self._error.get("errorCode") @property def error_name(self) -> Optional[str]: - return self._error.get("errorName", None) + return self._error.get("errorName") @property def error_type(self) -> Optional[str]: - return self._error.get("errorType", None) + return self._error.get("errorType") @property def error_exception(self) -> Optional[str]: - return self.failure_info.get("type", None) if self.failure_info else None + return self.failure_info.get("type") if self.failure_info else None @property def failure_info(self) -> Optional[Dict[str, Any]]: - return self._error.get("failureInfo", None) + return self._error.get("failureInfo") @property def message(self) -> str: return self._error.get("message", "Trino did not return an error message") @property - def error_location(self) -> Tuple[int, int]: - location = self._error["errorLocation"] - return (location["lineNumber"], location["columnNumber"]) + def error_location(self) -> Optional[Tuple[int, int]]: + location = self._error.get("errorLocation") + if location: + return (location["lineNumber"], location["columnNumber"]) + else: + return None @property def query_id(self) -> Optional[str]: