From 53c0c48987a039f75271202015f7303fd78134a5 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 9 Sep 2024 17:32:52 -0500 Subject: [PATCH 1/4] fix qmark cache issue in Athena --- awswrangler/athena/_cache.py | 42 ++++++++++++++++++----------- awswrangler/athena/_read.py | 1 + awswrangler/athena/_utils.py | 22 +++++++++------- tests/unit/test_athena.py | 51 ++++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 25 deletions(-) diff --git a/awswrangler/athena/_cache.py b/awswrangler/athena/_cache.py index 1acbb032f..44425fefa 100644 --- a/awswrangler/athena/_cache.py +++ b/awswrangler/athena/_cache.py @@ -7,7 +7,7 @@ import re import threading from heapq import heappop, heappush -from typing import TYPE_CHECKING, Any, Match, NamedTuple +from typing import TYPE_CHECKING, Match, NamedTuple import boto3 @@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple): has_valid_cache: bool file_format: str | None = None query_execution_id: str | None = None - query_execution_payload: dict[str, Any] | None = None + query_execution_payload: "QueryExecutionTypeDef" | None = None class _LocalMetadataCacheManager: def __init__(self) -> None: self._lock: threading.Lock = threading.Lock() - self._cache: dict[str, Any] = {} + self._cache: dict[str, "QueryExecutionTypeDef"] = {} self._pqueue: list[tuple[datetime.datetime, str]] = [] self._max_cache_size = 100 - def update_cache(self, items: list[dict[str, Any]]) -> None: + def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None: """ Update the local metadata cache with new query metadata. Parameters ---------- - items : List[Dict[str, Any]] + items List of query execution metadata which is returned by boto3 `batch_get_query_execution()`. 
""" with self._lock: @@ -62,7 +62,7 @@ def update_cache(self, items: list[dict[str, Any]]) -> None: heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"])) self._cache[item["QueryExecutionId"]] = item - def sorted_successful_generator(self) -> list[dict[str, Any]]: + def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]: """ Sorts the entries in the local cache based on query Completion DateTime. @@ -70,10 +70,9 @@ def sorted_successful_generator(self) -> list[dict[str, Any]]: Returns ------- - List[Dict[str, Any]] Returns successful DDL and DML queries sorted by query completion time. """ - filtered: list[dict[str, Any]] = [] + filtered: list["QueryExecutionTypeDef"] = [] for query in self._cache.values(): if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]): filtered.append(query) @@ -111,13 +110,13 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None: return None -def _compare_query_string(sql: str, other: str) -> bool: +def _compare_query_string( + sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None +) -> bool: comparison_query = _prepare_query_string_for_comparison(query_string=other) _logger.debug("sql: %s", sql) _logger.debug("comparison_query: %s", comparison_query) - if sql == comparison_query: - return True - return False + return sql == comparison_query and sql_params == other_params def _prepare_query_string_for_comparison(query_string: str) -> str: @@ -135,7 +134,7 @@ def _get_last_query_infos( max_remote_cache_entries: int, boto3_session: boto3.Session | None = None, workgroup: str | None = None, -) -> list[dict[str, Any]]: +) -> list["QueryExecutionTypeDef"]: """Return an iterator of `query_execution_info`s run by the workgroup in Athena.""" client_athena = _utils.client(service_name="athena", session=boto3_session) page_size = 50 @@ -160,7 +159,7 @@ def _get_last_query_infos( 
QueryExecutionIds=uncached_ids[i : i + page_size], ).get("QueryExecutions") ) - _cache_manager.update_cache(new_execution_data) # type: ignore[arg-type] + _cache_manager.update_cache(new_execution_data) return _cache_manager.sorted_successful_generator() @@ -168,6 +167,7 @@ def _check_for_cached_results( sql: str, boto3_session: boto3.Session | None, workgroup: str | None, + params: list[str] | None = None, athena_cache_settings: typing.AthenaCacheSettings | None = None, ) -> _CacheInfo: """ @@ -207,7 +207,12 @@ def _check_for_cached_results( if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"): parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"]) if parsed_query is not None: - if _compare_query_string(sql=comparable_sql, other=parsed_query): + if _compare_query_string( + sql=comparable_sql, + other=parsed_query, + sql_params=params, + other_params=query_info.get("ExecutionParameters"), + ): return _CacheInfo( has_valid_cache=True, file_format="parquet", @@ -215,7 +220,12 @@ def _check_for_cached_results( query_execution_payload=query_info, ) elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"): - if _compare_query_string(sql=comparable_sql, other=query_info["Query"]): + if _compare_query_string( + sql=comparable_sql, + other=query_info["Query"], + sql_params=params, + other_params=query_info.get("ExecutionParameters"), + ): return _CacheInfo( has_valid_cache=True, file_format="csv", diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 3c97bc08d..fb4baac5d 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1048,6 +1048,7 @@ def read_sql_query( if not client_request_token: cache_info: _CacheInfo = _check_for_cached_results( sql=sql, + params=params if paramstyle == "qmark" else None, boto3_session=boto3_session, workgroup=workgroup, athena_cache_settings=athena_cache_settings, diff --git a/awswrangler/athena/_utils.py 
b/awswrangler/athena/_utils.py index 6d71c10e3..6b6e54da3 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -35,6 +35,7 @@ from ._cache import _cache_manager, _LocalMetadataCacheManager if TYPE_CHECKING: + from mypy_boto3_athena.type_defs import QueryExecutionTypeDef from mypy_boto3_glue.type_defs import ColumnOutputTypeDef _QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"] @@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple): binaries: list[str] output_location: str | None manifest_location: str | None - raw_payload: dict[str, Any] + raw_payload: "QueryExecutionTypeDef" class _WorkGroupConfig(NamedTuple): @@ -214,7 +215,7 @@ def _get_query_metadata( query_execution_id: str, boto3_session: boto3.Session | None = None, categories: list[str] | None = None, - query_execution_payload: dict[str, Any] | None = None, + query_execution_payload: "QueryExecutionTypeDef" | None = None, metadata_cache_manager: _LocalMetadataCacheManager | None = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, execution_params: list[str] | None = None, @@ -225,12 +226,15 @@ def _get_query_metadata( if query_execution_payload["Status"]["State"] != "SUCCEEDED": reason: str = query_execution_payload["Status"]["StateChangeReason"] raise exceptions.QueryFailed(f"Query error: {reason}") - _query_execution_payload: dict[str, Any] = query_execution_payload + _query_execution_payload = query_execution_payload else: - _query_execution_payload = _executions.wait_query( - query_execution_id=query_execution_id, - boto3_session=boto3_session, - athena_query_wait_polling_delay=athena_query_wait_polling_delay, + _query_execution_payload = cast( + "QueryExecutionTypeDef", + _executions.wait_query( + query_execution_id=query_execution_id, + boto3_session=boto3_session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ), ) cols_types: dict[str, str] = get_query_columns_types( query_execution_id=query_execution_id, 
boto3_session=boto3_session @@ -266,8 +270,8 @@ def _get_query_metadata( if "ResultConfiguration" in _query_execution_payload: output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation") - athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {}) - manifest_location: str | None = str(athena_statistics.get("DataManifestLocation")) + athena_statistics = _query_execution_payload.get("Statistics", {}) + manifest_location: str | None = athena_statistics.get("DataManifestLocation") if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager: metadata_cache_manager.update_cache(items=[_query_execution_payload]) diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index cc0a5029f..6faae61be 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -461,6 +461,57 @@ def test_athena_paramstyle_qmark_parameters( assert len(df_out) == 1 +def test_athena_paramstyle_qmark_with_caching( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + ctas_approach: bool, + unload_approach: bool, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + df_out = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + database=glue_database, + ctas_approach=ctas_approach, + unload_approach=unload_approach, + workgroup=workgroup0, + params=["Washington"], + paramstyle="qmark", + keep_files=False, + s3_output=path2, + athena_cache_settings={"max_cache_seconds": 300} + ) + + assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington" + + df_out = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + database=glue_database, + ctas_approach=ctas_approach, + unload_approach=unload_approach, + workgroup=workgroup0, + params=["Seattle"], + 
paramstyle="qmark", + keep_files=False, + s3_output=path2, + athena_cache_settings={"max_cache_seconds": 300} + ) + + assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle" + + def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0): wr.s3.to_parquet( df=get_df(), From a554c47947d19d9edf26280852ccedc4c4efc86b Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 9 Sep 2024 17:34:58 -0500 Subject: [PATCH 2/4] fix formatting --- tests/unit/test_athena.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index 6faae61be..b0e8c49cc 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -491,7 +491,7 @@ def test_athena_paramstyle_qmark_with_caching( paramstyle="qmark", keep_files=False, s3_output=path2, - athena_cache_settings={"max_cache_seconds": 300} + athena_cache_settings={"max_cache_seconds": 300}, ) assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington" @@ -506,7 +506,7 @@ def test_athena_paramstyle_qmark_with_caching( paramstyle="qmark", keep_files=False, s3_output=path2, - athena_cache_settings={"max_cache_seconds": 300} + athena_cache_settings={"max_cache_seconds": 300}, ) assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle" From aebfddfffe8b6237758c7b5babf99274c6ce279d Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 10 Sep 2024 08:46:08 -0500 Subject: [PATCH 3/4] add missing test params --- tests/unit/test_athena.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index b0e8c49cc..890cbc702 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -461,6 +461,14 @@ def test_athena_paramstyle_qmark_parameters( assert len(df_out) == 1 +@pytest.mark.parametrize( + "ctas_approach,unload_approach", + [ + pytest.param(False, False, id="regular"), + pytest.param(True, False, id="ctas"), + 
pytest.param(False, True, id="unload"), + ], +) def test_athena_paramstyle_qmark_with_caching( path: str, path2: str, From 7d8101393121ded1039c2b0719b990912cc7cede Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 9 Oct 2024 13:10:17 -0500 Subject: [PATCH 4/4] remove caching for qmark --- awswrangler/athena/_cache.py | 21 ++++----------------- awswrangler/athena/_read.py | 6 ++++-- tests/unit/test_athena.py | 2 +- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/awswrangler/athena/_cache.py b/awswrangler/athena/_cache.py index 44425fefa..2e9252baf 100644 --- a/awswrangler/athena/_cache.py +++ b/awswrangler/athena/_cache.py @@ -110,13 +110,11 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None: return None -def _compare_query_string( - sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None -) -> bool: +def _compare_query_string(sql: str, other: str) -> bool: comparison_query = _prepare_query_string_for_comparison(query_string=other) _logger.debug("sql: %s", sql) _logger.debug("comparison_query: %s", comparison_query) - return sql == comparison_query and sql_params == other_params + return sql == comparison_query def _prepare_query_string_for_comparison(query_string: str) -> str: @@ -167,7 +165,6 @@ def _check_for_cached_results( sql: str, boto3_session: boto3.Session | None, workgroup: str | None, - params: list[str] | None = None, athena_cache_settings: typing.AthenaCacheSettings | None = None, ) -> _CacheInfo: """ @@ -207,12 +204,7 @@ def _check_for_cached_results( if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"): parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"]) if parsed_query is not None: - if _compare_query_string( - sql=comparable_sql, - other=parsed_query, - sql_params=params, - other_params=query_info.get("ExecutionParameters"), - ): + if _compare_query_string(sql=comparable_sql, 
other=parsed_query): return _CacheInfo( has_valid_cache=True, file_format="parquet", @@ -220,12 +212,7 @@ def _check_for_cached_results( query_execution_payload=query_info, ) elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"): - if _compare_query_string( - sql=comparable_sql, - other=query_info["Query"], - sql_params=params, - other_params=query_info.get("ExecutionParameters"), - ): + if _compare_query_string(sql=comparable_sql, other=query_info["Query"]): return _CacheInfo( has_valid_cache=True, file_format="csv", diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 546b5aaea..ef81a34c9 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1045,10 +1045,12 @@ def read_sql_query( # Substitute query parameters if applicable sql, execution_params = _apply_formatter(sql, params, paramstyle) - if not client_request_token: + if not client_request_token and paramstyle != "qmark": + # Skip the cached-results lookup for paramstyle == "qmark"; such queries must use Athena's caching option. + # When describing an Athena execution, the API does not return the parameters that were used, + # so a previous execution cannot be matched against the current parameter values. cache_info: _CacheInfo = _check_for_cached_results( sql=sql, - params=params if paramstyle == "qmark" else None, boto3_session=boto3_session, workgroup=workgroup, athena_cache_settings=athena_cache_settings, diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index 890cbc702..d747ae001 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -469,7 +469,7 @@ def test_athena_paramstyle_qmark_parameters( pytest.param(False, True, id="unload"), ], ) -def test_athena_paramstyle_qmark_with_caching( +def test_athena_paramstyle_qmark_skip_caching( path: str, path2: str, glue_database: str,