diff --git a/awswrangler/athena/_cache.py b/awswrangler/athena/_cache.py
index 1acbb032f..2e9252baf 100644
--- a/awswrangler/athena/_cache.py
+++ b/awswrangler/athena/_cache.py
@@ -7,7 +7,7 @@
 import re
 import threading
 from heapq import heappop, heappush
-from typing import TYPE_CHECKING, Any, Match, NamedTuple
+from typing import TYPE_CHECKING, Match, NamedTuple
 
 import boto3
 
@@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
     has_valid_cache: bool
     file_format: str | None = None
     query_execution_id: str | None = None
-    query_execution_payload: dict[str, Any] | None = None
+    query_execution_payload: "QueryExecutionTypeDef" | None = None
 
 
 class _LocalMetadataCacheManager:
     def __init__(self) -> None:
         self._lock: threading.Lock = threading.Lock()
-        self._cache: dict[str, Any] = {}
+        self._cache: dict[str, "QueryExecutionTypeDef"] = {}
         self._pqueue: list[tuple[datetime.datetime, str]] = []
         self._max_cache_size = 100
 
-    def update_cache(self, items: list[dict[str, Any]]) -> None:
+    def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
         """
         Update the local metadata cache with new query metadata.
 
         Parameters
         ----------
-        items : List[Dict[str, Any]]
+        items
             List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
         """
         with self._lock:
@@ -62,7 +62,7 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
             heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
             self._cache[item["QueryExecutionId"]] = item
 
-    def sorted_successful_generator(self) -> list[dict[str, Any]]:
+    def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
         """
         Sorts the entries in the local cache based on query Completion DateTime.
 
@@ -70,10 +70,9 @@ def sorted_successful_generator(self) -> list[dict[str, Any]]:
 
         Returns
         -------
-        List[Dict[str, Any]]
             Returns successful DDL and DML queries sorted by query completion time.
         """
-        filtered: list[dict[str, Any]] = []
+        filtered: list["QueryExecutionTypeDef"] = []
         for query in self._cache.values():
             if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
                 filtered.append(query)
@@ -115,9 +114,7 @@ def _compare_query_string(sql: str, other: str) -> bool:
     comparison_query = _prepare_query_string_for_comparison(query_string=other)
     _logger.debug("sql: %s", sql)
     _logger.debug("comparison_query: %s", comparison_query)
-    if sql == comparison_query:
-        return True
-    return False
+    return sql == comparison_query
 
 
 def _prepare_query_string_for_comparison(query_string: str) -> str:
@@ -135,7 +132,7 @@ def _get_last_query_infos(
     max_remote_cache_entries: int,
     boto3_session: boto3.Session | None = None,
     workgroup: str | None = None,
-) -> list[dict[str, Any]]:
+) -> list["QueryExecutionTypeDef"]:
     """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
     client_athena = _utils.client(service_name="athena", session=boto3_session)
     page_size = 50
@@ -160,7 +157,7 @@ def _get_last_query_infos(
                     QueryExecutionIds=uncached_ids[i : i + page_size],
                 ).get("QueryExecutions")
            )
-        _cache_manager.update_cache(new_execution_data)  # type: ignore[arg-type]
+        _cache_manager.update_cache(new_execution_data)
     return _cache_manager.sorted_successful_generator()
 
 
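The `_cache.py` hunks above swap `dict[str, Any]` for the `QueryExecutionTypeDef` TypedDict from the `mypy-boto3-athena` stubs, written as a string annotation so the stub package is only imported for type checking. A minimal standalone sketch of that pattern follows; the `describe_query` helper is illustrative only, not code from this PR:

from __future__ import annotations

from typing import TYPE_CHECKING

import boto3

if TYPE_CHECKING:
    # Stub-only import: mypy-boto3-athena is a type-checking dependency, never loaded at runtime.
    from mypy_boto3_athena.type_defs import QueryExecutionTypeDef


def describe_query(query_execution_id: str) -> "QueryExecutionTypeDef":
    """Hypothetical helper showing the quoted-annotation pattern used in the diff."""
    athena = boto3.client("athena")
    # The TypedDict lets mypy validate key access such as payload["Status"]["State"].
    return athena.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]
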
""" - filtered: list[dict[str, Any]] = [] + filtered: list["QueryExecutionTypeDef"] = [] for query in self._cache.values(): if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]): filtered.append(query) @@ -115,9 +114,7 @@ def _compare_query_string(sql: str, other: str) -> bool: comparison_query = _prepare_query_string_for_comparison(query_string=other) _logger.debug("sql: %s", sql) _logger.debug("comparison_query: %s", comparison_query) - if sql == comparison_query: - return True - return False + return sql == comparison_query def _prepare_query_string_for_comparison(query_string: str) -> str: @@ -135,7 +132,7 @@ def _get_last_query_infos( max_remote_cache_entries: int, boto3_session: boto3.Session | None = None, workgroup: str | None = None, -) -> list[dict[str, Any]]: +) -> list["QueryExecutionTypeDef"]: """Return an iterator of `query_execution_info`s run by the workgroup in Athena.""" client_athena = _utils.client(service_name="athena", session=boto3_session) page_size = 50 @@ -160,7 +157,7 @@ def _get_last_query_infos( QueryExecutionIds=uncached_ids[i : i + page_size], ).get("QueryExecutions") ) - _cache_manager.update_cache(new_execution_data) # type: ignore[arg-type] + _cache_manager.update_cache(new_execution_data) return _cache_manager.sorted_successful_generator() diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 0286ce5de..ef81a34c9 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1045,7 +1045,10 @@ def read_sql_query( # Substitute query parameters if applicable sql, execution_params = _apply_formatter(sql, params, paramstyle) - if not client_request_token: + if not client_request_token and paramstyle != "qmark": + # For paramstyle=="qmark", we will need to use Athena's caching option. + # The issue is that when describing an Athena execution, the API does not return + # the parameters that were used. 
diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py
index 6d71c10e3..6b6e54da3 100644
--- a/awswrangler/athena/_utils.py
+++ b/awswrangler/athena/_utils.py
@@ -35,6 +35,7 @@
 from ._cache import _cache_manager, _LocalMetadataCacheManager
 
 if TYPE_CHECKING:
+    from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
     from mypy_boto3_glue.type_defs import ColumnOutputTypeDef
 
 _QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
@@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple):
     binaries: list[str]
     output_location: str | None
     manifest_location: str | None
-    raw_payload: dict[str, Any]
+    raw_payload: "QueryExecutionTypeDef"
 
 
 class _WorkGroupConfig(NamedTuple):
@@ -214,7 +215,7 @@ def _get_query_metadata(
     query_execution_id: str,
     boto3_session: boto3.Session | None = None,
     categories: list[str] | None = None,
-    query_execution_payload: dict[str, Any] | None = None,
+    query_execution_payload: "QueryExecutionTypeDef" | None = None,
     metadata_cache_manager: _LocalMetadataCacheManager | None = None,
     athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
     execution_params: list[str] | None = None,
@@ -225,12 +226,15 @@ def _get_query_metadata(
         if query_execution_payload["Status"]["State"] != "SUCCEEDED":
             reason: str = query_execution_payload["Status"]["StateChangeReason"]
             raise exceptions.QueryFailed(f"Query error: {reason}")
-        _query_execution_payload: dict[str, Any] = query_execution_payload
+        _query_execution_payload = query_execution_payload
     else:
-        _query_execution_payload = _executions.wait_query(
-            query_execution_id=query_execution_id,
-            boto3_session=boto3_session,
-            athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+        _query_execution_payload = cast(
+            "QueryExecutionTypeDef",
+            _executions.wait_query(
+                query_execution_id=query_execution_id,
+                boto3_session=boto3_session,
+                athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+            ),
         )
     cols_types: dict[str, str] = get_query_columns_types(
         query_execution_id=query_execution_id, boto3_session=boto3_session
@@ -266,8 +270,8 @@ def _get_query_metadata(
     if "ResultConfiguration" in _query_execution_payload:
         output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")
 
-    athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {})
-    manifest_location: str | None = str(athena_statistics.get("DataManifestLocation"))
+    athena_statistics = _query_execution_payload.get("Statistics", {})
+    manifest_location: str | None = athena_statistics.get("DataManifestLocation")
 
     if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
         metadata_cache_manager.update_cache(items=[_query_execution_payload])
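Besides tightening the annotations, the last `_utils.py` hunk also changes the fallback value: the old `str(...)` wrapper turned a missing `DataManifestLocation` into the literal string "None", whereas the new code yields `None`. A small standalone illustration of the difference (not awswrangler code):

from typing import Any

statistics: dict[str, Any] = {}  # e.g. Statistics from a query that produced no data manifest

old_value = str(statistics.get("DataManifestLocation"))  # "None": a truthy, non-empty string
new_value = statistics.get("DataManifestLocation")       # None

assert old_value == "None" and bool(old_value)
assert new_value is None
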
diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py
index cc0a5029f..d747ae001 100644
--- a/tests/unit/test_athena.py
+++ b/tests/unit/test_athena.py
@@ -461,6 +461,65 @@ def test_athena_paramstyle_qmark_parameters(
     assert len(df_out) == 1
 
 
+@pytest.mark.parametrize(
+    "ctas_approach,unload_approach",
+    [
+        pytest.param(False, False, id="regular"),
+        pytest.param(True, False, id="ctas"),
+        pytest.param(False, True, id="unload"),
+    ],
+)
+def test_athena_paramstyle_qmark_skip_caching(
+    path: str,
+    path2: str,
+    glue_database: str,
+    glue_table: str,
+    workgroup0: str,
+    ctas_approach: bool,
+    unload_approach: bool,
+) -> None:
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=False,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    df_out = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM {glue_table} WHERE string = ?",
+        database=glue_database,
+        ctas_approach=ctas_approach,
+        unload_approach=unload_approach,
+        workgroup=workgroup0,
+        params=["Washington"],
+        paramstyle="qmark",
+        keep_files=False,
+        s3_output=path2,
+        athena_cache_settings={"max_cache_seconds": 300},
+    )
+
+    assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington"
+
+    df_out = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM {glue_table} WHERE string = ?",
+        database=glue_database,
+        ctas_approach=ctas_approach,
+        unload_approach=unload_approach,
+        workgroup=workgroup0,
+        params=["Seattle"],
+        paramstyle="qmark",
+        keep_files=False,
+        s3_output=path2,
+        athena_cache_settings={"max_cache_seconds": 300},
+    )
+
+    assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle"
+
+
 def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0):
     wr.s3.to_parquet(
         df=get_df(),
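The new test pins down the user-visible behavior: with `paramstyle="qmark"` and result caching enabled, two reads that differ only in their parameter values must not serve each other's results. Condensed to a usage sketch (table and database names are placeholders standing in for the fixtures above):

import awswrangler as wr


def read(value: str):
    # Mirrors the test: identical SQL text, only the bound parameter changes.
    return wr.athena.read_sql_query(
        sql="SELECT * FROM my_table WHERE string = ?",
        database="my_database",
        params=[value],
        paramstyle="qmark",
        athena_cache_settings={"max_cache_seconds": 300},
    )


# Without the guard in _read.py, the second call could be answered from the first call's
# cached execution because both share the same SQL text; with it, a fresh query runs.
assert read("Washington").iloc[0]["string"] == "Washington"
assert read("Seattle").iloc[0]["string"] == "Seattle"
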