From 7d8101393121ded1039c2b0719b990912cc7cede Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 9 Oct 2024 13:10:17 -0500 Subject: [PATCH] remove caching for qmark --- awswrangler/athena/_cache.py | 21 ++++----------------- awswrangler/athena/_read.py | 6 ++++-- tests/unit/test_athena.py | 2 +- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/awswrangler/athena/_cache.py b/awswrangler/athena/_cache.py index 44425fefa..2e9252baf 100644 --- a/awswrangler/athena/_cache.py +++ b/awswrangler/athena/_cache.py @@ -110,13 +110,11 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None: return None -def _compare_query_string( - sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None -) -> bool: +def _compare_query_string(sql: str, other: str) -> bool: comparison_query = _prepare_query_string_for_comparison(query_string=other) _logger.debug("sql: %s", sql) _logger.debug("comparison_query: %s", comparison_query) - return sql == comparison_query and sql_params == other_params + return sql == comparison_query def _prepare_query_string_for_comparison(query_string: str) -> str: @@ -167,7 +165,6 @@ def _check_for_cached_results( sql: str, boto3_session: boto3.Session | None, workgroup: str | None, - params: list[str] | None = None, athena_cache_settings: typing.AthenaCacheSettings | None = None, ) -> _CacheInfo: """ @@ -207,12 +204,7 @@ def _check_for_cached_results( if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"): parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"]) if parsed_query is not None: - if _compare_query_string( - sql=comparable_sql, - other=parsed_query, - sql_params=params, - other_params=query_info.get("ExecutionParameters"), - ): + if _compare_query_string(sql=comparable_sql, other=parsed_query): return _CacheInfo( has_valid_cache=True, file_format="parquet", @@ -220,12 +212,7 @@ def _check_for_cached_results( query_execution_payload=query_info, ) elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"): - if _compare_query_string( - sql=comparable_sql, - other=query_info["Query"], - sql_params=params, - other_params=query_info.get("ExecutionParameters"), - ): + if _compare_query_string(sql=comparable_sql, other=query_info["Query"]): return _CacheInfo( has_valid_cache=True, file_format="csv", diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 546b5aaea..ef81a34c9 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1045,10 +1045,12 @@ def read_sql_query( # Substitute query parameters if applicable sql, execution_params = _apply_formatter(sql, params, paramstyle) - if not client_request_token: + if not client_request_token and paramstyle != "qmark": + # For paramstyle=="qmark", we will need to use Athena's caching option. + # The issue is that when describing an Athena execution, the API does not return + # the parameters that were used. cache_info: _CacheInfo = _check_for_cached_results( sql=sql, - params=params if paramstyle == "qmark" else None, boto3_session=boto3_session, workgroup=workgroup, athena_cache_settings=athena_cache_settings, diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index 890cbc702..d747ae001 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -469,7 +469,7 @@ def test_athena_paramstyle_qmark_parameters( pytest.param(False, True, id="unload"), ], ) -def test_athena_paramstyle_qmark_with_caching( +def test_athena_paramstyle_qmark_skip_caching( path: str, path2: str, glue_database: str,