fix qmark cache issue in Athena
LeonLuttenberger committed Sep 9, 2024
1 parent 78522fd commit 53c0c48
Showing 4 changed files with 91 additions and 25 deletions.
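
The underlying issue: awswrangler's Athena result cache matched queries on SQL text alone, so two `qmark`-parameterized queries with identical SQL but different parameters could be served the same cached result. This commit threads the `qmark` parameters into the cache lookup and compares them against the `ExecutionParameters` recorded on each cached query execution. A hypothetical repro of the pre-fix behavior (database and table names are placeholders; the pattern mirrors the new unit test at the bottom of this diff):

```python
# Hypothetical repro of the bug fixed in this commit. "my_db" and "my_table"
# are placeholder names; the pattern mirrors
# test_athena_paramstyle_qmark_with_caching below.
import awswrangler as wr

df1 = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table WHERE string = ?",
    database="my_db",
    params=["Washington"],
    paramstyle="qmark",
    athena_cache_settings={"max_cache_seconds": 300},
)

# Same SQL text, different parameter. Before this fix the cache compared
# only the SQL string, so this call could return df1's cached result.
df2 = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table WHERE string = ?",
    database="my_db",
    params=["Seattle"],
    paramstyle="qmark",
    athena_cache_settings={"max_cache_seconds": 300},
)
```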
awswrangler/athena/_cache.py (26 additions, 16 deletions)

```diff
@@ -7,7 +7,7 @@
 import re
 import threading
 from heapq import heappop, heappush
-from typing import TYPE_CHECKING, Any, Match, NamedTuple
+from typing import TYPE_CHECKING, Match, NamedTuple
 
 import boto3
 
@@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
     has_valid_cache: bool
     file_format: str | None = None
     query_execution_id: str | None = None
-    query_execution_payload: dict[str, Any] | None = None
+    query_execution_payload: "QueryExecutionTypeDef" | None = None
 
 
 class _LocalMetadataCacheManager:
     def __init__(self) -> None:
         self._lock: threading.Lock = threading.Lock()
-        self._cache: dict[str, Any] = {}
+        self._cache: dict[str, "QueryExecutionTypeDef"] = {}
         self._pqueue: list[tuple[datetime.datetime, str]] = []
         self._max_cache_size = 100
 
-    def update_cache(self, items: list[dict[str, Any]]) -> None:
+    def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
         """
         Update the local metadata cache with new query metadata.
 
         Parameters
         ----------
-        items : List[Dict[str, Any]]
+        items
             List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
         """
         with self._lock:
@@ -62,18 +62,17 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
                 heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
                 self._cache[item["QueryExecutionId"]] = item
 
-    def sorted_successful_generator(self) -> list[dict[str, Any]]:
+    def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
         """
         Sorts the entries in the local cache based on query Completion DateTime.
 
         This is useful to guarantee LRU caching rules.
 
         Returns
         -------
-        List[Dict[str, Any]]
             Returns successful DDL and DML queries sorted by query completion time.
         """
-        filtered: list[dict[str, Any]] = []
+        filtered: list["QueryExecutionTypeDef"] = []
         for query in self._cache.values():
             if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
                 filtered.append(query)
@@ -111,13 +110,13 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
     return None
 
 
-def _compare_query_string(sql: str, other: str) -> bool:
+def _compare_query_string(
+    sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None
+) -> bool:
     comparison_query = _prepare_query_string_for_comparison(query_string=other)
     _logger.debug("sql: %s", sql)
     _logger.debug("comparison_query: %s", comparison_query)
-    if sql == comparison_query:
-        return True
-    return False
+    return sql == comparison_query and sql_params == other_params
 
 
 def _prepare_query_string_for_comparison(query_string: str) -> str:
```
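
The rewritten `_compare_query_string` requires both the normalized SQL text and the parameter lists to match before reporting a cache hit. A standalone sketch of the new semantics (normalization stubbed out; the real function first runs `other` through `_prepare_query_string_for_comparison`):

```python
# Standalone sketch of the new cache-hit predicate; not the real function.
def compare(
    sql: str,
    other: str,
    sql_params: list[str] | None = None,
    other_params: list[str] | None = None,
) -> bool:
    # Both the SQL text and the qmark parameter lists must match.
    return sql == other and sql_params == other_params

assert compare("select 1 where x = ?", "select 1 where x = ?", ["a"], ["a"])
assert not compare("select 1 where x = ?", "select 1 where x = ?", ["a"], ["b"])
assert compare("select 1", "select 1")  # unparameterized: None == None still hits
```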
```diff
@@ -135,7 +134,7 @@ def _get_last_query_infos(
     max_remote_cache_entries: int,
     boto3_session: boto3.Session | None = None,
     workgroup: str | None = None,
-) -> list[dict[str, Any]]:
+) -> list["QueryExecutionTypeDef"]:
     """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
     client_athena = _utils.client(service_name="athena", session=boto3_session)
     page_size = 50
@@ -160,14 +159,15 @@ def _get_last_query_infos(
                 QueryExecutionIds=uncached_ids[i : i + page_size],
             ).get("QueryExecutions")
         )
-    _cache_manager.update_cache(new_execution_data)  # type: ignore[arg-type]
+    _cache_manager.update_cache(new_execution_data)
     return _cache_manager.sorted_successful_generator()
 
 
 def _check_for_cached_results(
     sql: str,
     boto3_session: boto3.Session | None,
     workgroup: str | None,
+    params: list[str] | None = None,
     athena_cache_settings: typing.AthenaCacheSettings | None = None,
 ) -> _CacheInfo:
     """
```
```diff
@@ -207,15 +207,25 @@ def _check_for_cached_results(
         if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
             parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
             if parsed_query is not None:
-                if _compare_query_string(sql=comparable_sql, other=parsed_query):
+                if _compare_query_string(
+                    sql=comparable_sql,
+                    other=parsed_query,
+                    sql_params=params,
+                    other_params=query_info.get("ExecutionParameters"),
+                ):
                     return _CacheInfo(
                         has_valid_cache=True,
                         file_format="parquet",
                         query_execution_id=query_execution_id,
                         query_execution_payload=query_info,
                     )
         elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
-            if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
+            if _compare_query_string(
+                sql=comparable_sql,
+                other=query_info["Query"],
+                sql_params=params,
+                other_params=query_info.get("ExecutionParameters"),
+            ):
                 return _CacheInfo(
                     has_valid_cache=True,
                     file_format="csv",
```
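
Note that `query_info.get("ExecutionParameters")` reads the parameter list Athena itself records on each execution (the `ExecutionParameters` field of the `GetQueryExecution` response). For unparameterized queries the field is absent, so both sides of the comparison are `None` and plain queries keep hitting the cache exactly as before.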
awswrangler/athena/_read.py (1 addition, 0 deletions)

```diff
@@ -1048,6 +1048,7 @@ def read_sql_query(
     if not client_request_token:
         cache_info: _CacheInfo = _check_for_cached_results(
             sql=sql,
+            params=params if paramstyle == "qmark" else None,
             boto3_session=boto3_session,
             workgroup=workgroup,
             athena_cache_settings=athena_cache_settings,
```
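
Only `qmark` parameters reach the cache check. With `paramstyle="python"`, awswrangler interpolates the parameters into the SQL string client-side before execution, so the text comparison already distinguishes those queries; `qmark` parameters are sent to Athena separately and therefore need their own comparison.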
awswrangler/athena/_utils.py (13 additions, 9 deletions)

```diff
@@ -35,6 +35,7 @@
 from ._cache import _cache_manager, _LocalMetadataCacheManager
 
 if TYPE_CHECKING:
+    from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
     from mypy_boto3_glue.type_defs import ColumnOutputTypeDef
 
 _QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
@@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple):
     binaries: list[str]
     output_location: str | None
     manifest_location: str | None
-    raw_payload: dict[str, Any]
+    raw_payload: "QueryExecutionTypeDef"
 
 
 class _WorkGroupConfig(NamedTuple):
@@ -214,7 +215,7 @@ def _get_query_metadata(
     query_execution_id: str,
     boto3_session: boto3.Session | None = None,
     categories: list[str] | None = None,
-    query_execution_payload: dict[str, Any] | None = None,
+    query_execution_payload: "QueryExecutionTypeDef" | None = None,
     metadata_cache_manager: _LocalMetadataCacheManager | None = None,
     athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
     execution_params: list[str] | None = None,
@@ -225,12 +226,15 @@ def _get_query_metadata(
         if query_execution_payload["Status"]["State"] != "SUCCEEDED":
             reason: str = query_execution_payload["Status"]["StateChangeReason"]
             raise exceptions.QueryFailed(f"Query error: {reason}")
-        _query_execution_payload: dict[str, Any] = query_execution_payload
+        _query_execution_payload = query_execution_payload
     else:
-        _query_execution_payload = _executions.wait_query(
-            query_execution_id=query_execution_id,
-            boto3_session=boto3_session,
-            athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+        _query_execution_payload = cast(
+            "QueryExecutionTypeDef",
+            _executions.wait_query(
+                query_execution_id=query_execution_id,
+                boto3_session=boto3_session,
+                athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+            ),
         )
     cols_types: dict[str, str] = get_query_columns_types(
         query_execution_id=query_execution_id, boto3_session=boto3_session
@@ -266,8 +270,8 @@ def _get_query_metadata(
     if "ResultConfiguration" in _query_execution_payload:
         output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")
 
-    athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {})
-    manifest_location: str | None = str(athena_statistics.get("DataManifestLocation"))
+    athena_statistics = _query_execution_payload.get("Statistics", {})
+    manifest_location: str | None = athena_statistics.get("DataManifestLocation")
 
     if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
         metadata_cache_manager.update_cache(items=[_query_execution_payload])
```
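
Two details worth calling out here. First, `typing.cast` is an identity function at runtime; it only narrows the declared return type of `wait_query` to the mypy-boto3 `QueryExecutionTypeDef` for the type checker. Second, the `Statistics` change fixes a latent bug: `str(athena_statistics.get("DataManifestLocation"))` turned a missing key into the literal string `"None"`, whereas the new code yields a genuine `None`. A minimal sketch of the cast pattern, assuming the `mypy-boto3-athena` stubs are installed:

```python
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    # Stub-only dependency; never imported at runtime.
    from mypy_boto3_athena.type_defs import QueryExecutionTypeDef


def narrow(payload: dict[str, Any]) -> "QueryExecutionTypeDef":
    # cast() has no runtime effect: it returns `payload` unchanged and
    # merely tells the type checker to treat it as the boto3 TypedDict.
    return cast("QueryExecutionTypeDef", payload)
```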
tests/unit/test_athena.py (51 additions, 0 deletions)

```diff
@@ -461,6 +461,57 @@ def test_athena_paramstyle_qmark_parameters(
     assert len(df_out) == 1
 
 
+def test_athena_paramstyle_qmark_with_caching(
+    path: str,
+    path2: str,
+    glue_database: str,
+    glue_table: str,
+    workgroup0: str,
+    ctas_approach: bool,
+    unload_approach: bool,
+) -> None:
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=False,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    df_out = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM {glue_table} WHERE string = ?",
+        database=glue_database,
+        ctas_approach=ctas_approach,
+        unload_approach=unload_approach,
+        workgroup=workgroup0,
+        params=["Washington"],
+        paramstyle="qmark",
+        keep_files=False,
+        s3_output=path2,
+        athena_cache_settings={"max_cache_seconds": 300},
+    )
+
+    assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington"
+
+    df_out = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM {glue_table} WHERE string = ?",
+        database=glue_database,
+        ctas_approach=ctas_approach,
+        unload_approach=unload_approach,
+        workgroup=workgroup0,
+        params=["Seattle"],
+        paramstyle="qmark",
+        keep_files=False,
+        s3_output=path2,
+        athena_cache_settings={"max_cache_seconds": 300},
+    )
+
+    assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle"
+
+
 def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0):
     wr.s3.to_parquet(
         df=get_df(),
```
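
The new test exercises the fix end to end: the same `qmark` query runs twice with caching enabled (`max_cache_seconds: 300`) but with different parameters, and the second assertion fails if the "Seattle" call is served the cached "Washington" result. (`ctas_approach` and `unload_approach` appear to come from pytest parametrization elided from this diff, as in the neighboring tests.)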
