Skip to content

Commit

Permalink
remove caching for qmark
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonLuttenberger committed Oct 9, 2024
1 parent 85e7a48 commit 7d81013
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 20 deletions.
21 changes: 4 additions & 17 deletions awswrangler/athena/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,11 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
return None


def _compare_query_string(
sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None
) -> bool:
def _compare_query_string(sql: str, other: str) -> bool:
comparison_query = _prepare_query_string_for_comparison(query_string=other)
_logger.debug("sql: %s", sql)
_logger.debug("comparison_query: %s", comparison_query)
return sql == comparison_query and sql_params == other_params
return sql == comparison_query


def _prepare_query_string_for_comparison(query_string: str) -> str:
Expand Down Expand Up @@ -167,7 +165,6 @@ def _check_for_cached_results(
sql: str,
boto3_session: boto3.Session | None,
workgroup: str | None,
params: list[str] | None = None,
athena_cache_settings: typing.AthenaCacheSettings | None = None,
) -> _CacheInfo:
"""
Expand Down Expand Up @@ -207,25 +204,15 @@ def _check_for_cached_results(
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(
sql=comparable_sql,
other=parsed_query,
sql_params=params,
other_params=query_info.get("ExecutionParameters"),
):
if _compare_query_string(sql=comparable_sql, other=parsed_query):
return _CacheInfo(
has_valid_cache=True,
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(
sql=comparable_sql,
other=query_info["Query"],
sql_params=params,
other_params=query_info.get("ExecutionParameters"),
):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
Expand Down
6 changes: 4 additions & 2 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,10 +1045,12 @@ def read_sql_query(
# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)

if not client_request_token:
if not client_request_token and paramstyle != "qmark":
# For paramstyle=="qmark", we will need to use Athena's caching option.
# The issue is that when describing an Athena execution, the API does not return
# the parameters that were used.
cache_info: _CacheInfo = _check_for_cached_results(
sql=sql,
params=params if paramstyle == "qmark" else None,
boto3_session=boto3_session,
workgroup=workgroup,
athena_cache_settings=athena_cache_settings,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def test_athena_paramstyle_qmark_parameters(
pytest.param(False, True, id="unload"),
],
)
def test_athena_paramstyle_qmark_with_caching(
def test_athena_paramstyle_qmark_skip_caching(
path: str,
path2: str,
glue_database: str,
Expand Down

0 comments on commit 7d81013

Please sign in to comment.