Skip to content

Commit

Permalink
fix: Athena/Neptune minor fixes (#2526)
Browse files Browse the repository at this point in the history
* fix: Raise EmptyDataFrame exception in wr.neptune.to_property_graph

* fix: Explicitly use primary Athena workgroup
  • Loading branch information
kukushking authored Nov 21, 2023
1 parent d1f525c commit 28880d5
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 54 deletions.
6 changes: 3 additions & 3 deletions awswrangler/athena/_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def start_query_execution(
sql: str,
database: Optional[str] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
params: Union[Dict[str, Any], List[str], None] = None,
Expand All @@ -62,8 +62,8 @@ def start_query_execution(
AWS Glue/Athena database name.
s3_output : str, optional
AWS S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
kms_key : str, optional
Expand Down
6 changes: 3 additions & 3 deletions awswrangler/athena/_executions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def start_query_execution(
sql: str,
database: Optional[str] = ...,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
Expand All @@ -34,7 +34,7 @@ def start_query_execution(
*,
database: Optional[str] = ...,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
Expand All @@ -51,7 +51,7 @@ def start_query_execution(
*,
database: Optional[str] = ...,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
Expand Down
18 changes: 9 additions & 9 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
categories: Optional[List[str]] = None,
chunksize: Optional[Union[int, bool]] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
keep_files: bool = True,
Expand Down Expand Up @@ -929,8 +929,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
If an `INTEGER` is passed awswrangler will iterate on the data by number of rows equal to the received INTEGER.
s3_output : str, optional
Amazon S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
kms_key : str, optional
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def read_sql_table(
categories: Optional[List[str]] = None,
chunksize: Optional[Union[int, bool]] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
keep_files: bool = True,
Expand Down Expand Up @@ -1272,8 +1272,8 @@ def read_sql_table(
If an `INTEGER` is passed awswrangler will iterate on the data by number of rows equal to the received INTEGER.
s3_output : str, optional
AWS S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
kms_key : str, optional
Expand Down Expand Up @@ -1364,7 +1364,7 @@ def unload(
compression: Optional[str] = None,
field_delimiter: Optional[str] = None,
partitioned_by: Optional[List[str]] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
boto3_session: Optional[boto3.Session] = None,
Expand Down Expand Up @@ -1397,8 +1397,8 @@ def unload(
A single-character field delimiter for files in CSV, TSV, and other text formats.
partitioned_by : Optional[List[str]]
An array list of columns by which the output is partitioned.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
kms_key : str, optional
Expand Down
22 changes: 11 additions & 11 deletions awswrangler/athena/_read.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def read_sql_query( # pylint: disable=too-many-arguments
categories: Optional[List[str]] = ...,
chunksize: Union[None, Literal[False]] = ...,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -91,7 +91,7 @@ def read_sql_query(
categories: Optional[List[str]] = ...,
chunksize: Literal[True],
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -118,7 +118,7 @@ def read_sql_query(
categories: Optional[List[str]] = ...,
chunksize: bool,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -145,7 +145,7 @@ def read_sql_query(
categories: Optional[List[str]] = ...,
chunksize: int,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -172,7 +172,7 @@ def read_sql_query(
categories: Optional[List[str]] = ...,
chunksize: Optional[Union[int, bool]],
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -199,7 +199,7 @@ def read_sql_table(
categories: Optional[List[str]] = ...,
chunksize: Union[None, Literal[False]] = ...,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -223,7 +223,7 @@ def read_sql_table(
categories: Optional[List[str]] = ...,
chunksize: Literal[True],
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -247,7 +247,7 @@ def read_sql_table(
categories: Optional[List[str]] = ...,
chunksize: bool,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -271,7 +271,7 @@ def read_sql_table(
categories: Optional[List[str]] = ...,
chunksize: int,
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -295,7 +295,7 @@ def read_sql_table(
categories: Optional[List[str]] = ...,
chunksize: Optional[Union[int, bool]],
s3_output: Optional[str] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
keep_files: bool = ...,
Expand All @@ -315,7 +315,7 @@ def unload(
compression: Optional[str] = ...,
field_delimiter: Optional[str] = ...,
partitioned_by: Optional[List[str]] = ...,
workgroup: Optional[str] = ...,
workgroup: str = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
boto3_session: Optional[boto3.Session] = ...,
Expand Down
20 changes: 8 additions & 12 deletions awswrangler/athena/_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _does_statement_exist(
def create_prepared_statement(
sql: str,
statement_name: str,
workgroup: Optional[str] = None,
workgroup: str = "primary",
mode: Literal["update", "error"] = "update",
boto3_session: Optional[boto3.Session] = None,
) -> None:
Expand All @@ -50,8 +50,8 @@ def create_prepared_statement(
The query string for the prepared statement.
statement_name : str
The name of the prepared statement.
workgroup : str, optional
The name of the workgroup to which the prepared statement belongs.
workgroup : str
The name of the workgroup to which the prepared statement belongs. Primary by default.
mode : str
Determines the behaviour if the prepared statement already exists:
Expand All @@ -72,7 +72,6 @@ def create_prepared_statement(
raise exceptions.InvalidArgumentValue("`mode` must be one of 'update' or 'error'.")

athena_client = _utils.client("athena", session=boto3_session)
workgroup = workgroup if workgroup else "primary"

already_exists = _does_statement_exist(statement_name, workgroup, athena_client)
if already_exists and mode == "error":
Expand All @@ -95,16 +94,14 @@ def create_prepared_statement(


@apply_configs
def list_prepared_statements(
workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> List[str]:
def list_prepared_statements(workgroup: str = "primary", boto3_session: Optional[boto3.Session] = None) -> List[str]:
"""
List the prepared statements in the specified workgroup.
Parameters
----------
workgroup: str, optional
The name of the workgroup to which the prepared statement belongs.
workgroup : str
The name of the workgroup to which the prepared statement belongs. Primary by default.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
Expand All @@ -115,7 +112,6 @@ def list_prepared_statements(
Each item is a dictionary with the keys ``StatementName`` and ``LastModifiedTime``.
"""
athena_client = _utils.client("athena", session=boto3_session)
workgroup = workgroup if workgroup else "primary"

response = athena_client.list_prepared_statements(WorkGroup=workgroup)
statements = response["PreparedStatements"]
Expand All @@ -130,7 +126,7 @@ def list_prepared_statements(
@apply_configs
def delete_prepared_statement(
statement_name: str,
workgroup: Optional[str] = None,
workgroup: str = "primary",
boto3_session: Optional[boto3.Session] = None,
) -> None:
"""
Expand All @@ -143,7 +139,7 @@ def delete_prepared_statement(
statement_name : str
The name of the prepared statement.
workgroup : str
The name of the workgroup to which the prepared statement belongs.
The name of the workgroup to which the prepared statement belongs. Primary by default.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
Expand Down
26 changes: 13 additions & 13 deletions awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _start_query_execution(
return response["QueryExecutionId"]


def _get_workgroup_config(session: Optional[boto3.Session] = None, workgroup: Optional[str] = None) -> _WorkGroupConfig:
def _get_workgroup_config(session: Optional[boto3.Session] = None, workgroup: str = "primary") -> _WorkGroupConfig:
enforced: bool
wg_s3_output: Optional[str]
wg_encryption: Optional[str]
Expand Down Expand Up @@ -472,7 +472,7 @@ def repair_table(
database: Optional[str] = None,
data_source: Optional[str] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
Expand Down Expand Up @@ -501,8 +501,8 @@ def repair_table(
Data Source / Catalog name. If None, 'AwsDataCatalog' is used.
s3_output : str, optional
AWS S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
kms_key : str, optional
Expand Down Expand Up @@ -552,7 +552,7 @@ def describe_table(
table: str,
database: Optional[str] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
Expand All @@ -577,8 +577,8 @@ def describe_table(
AWS Glue/Athena database name.
s3_output : str, optional
AWS S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
kms_key : str, optional
Expand Down Expand Up @@ -641,7 +641,7 @@ def create_ctas_table( # pylint: disable=too-many-locals
bucketing_info: Optional[typing.BucketingInfoTuple] = None,
field_delimiter: Optional[str] = None,
schema_only: bool = False,
workgroup: Optional[str] = None,
workgroup: str = "primary",
data_source: Optional[str] = None,
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
Expand Down Expand Up @@ -686,8 +686,8 @@ def create_ctas_table( # pylint: disable=too-many-locals
The single-character field delimiter for files in CSV, TSV, and text files.
schema_only : bool, optional
If True, only the table schema is created and no data is written to it, by default False
workgroup : Optional[str], optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
data_source : Optional[str], optional
Data Source / Catalog name. If None, 'AwsDataCatalog' is used.
encryption : str, optional
Expand Down Expand Up @@ -856,7 +856,7 @@ def show_create_table(
table: str,
database: Optional[str] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
Expand All @@ -880,8 +880,8 @@ def show_create_table(
AWS Glue/Athena database name.
s3_output : str, optional
AWS S3 path.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
kms_key : str, optional
Expand Down
6 changes: 3 additions & 3 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def to_iceberg(
partition_cols: Optional[List[str]] = None,
keep_files: bool = True,
data_source: Optional[str] = None,
workgroup: Optional[str] = None,
workgroup: str = "primary",
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
boto3_session: Optional[boto3.Session] = None,
Expand Down Expand Up @@ -237,8 +237,8 @@ def to_iceberg(
Whether staging files produced by Athena are retained. 'True' by default.
data_source : str, optional
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
workgroup : str, optional
Athena workgroup.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
kms_key : str, optional
Expand Down
2 changes: 2 additions & 0 deletions awswrangler/neptune/_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def to_property_graph(
raise exceptions.InvalidArgumentValue(
"DataFrame must contain at least a ~id and a ~label column to be saved to Amazon Neptune"
)
if df.empty:
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")

# Loop through items in the DF
for index, row in df.iterrows():
Expand Down
Loading

0 comments on commit 28880d5

Please sign in to comment.