feat: add analysis_template_arn to cleanrooms read_sql_query (#2584)
* feat: add analysis_template_arn to cleanrooms read_sql_query

* fix: PR feedback on typing and docs
jaidisido authored Jan 3, 2024
1 parent f6b71a9 commit a34607d
Showing 4 changed files with 153 additions and 11 deletions.
54 changes: 44 additions & 10 deletions awswrangler/cleanrooms/_read.py
@@ -1,15 +1,18 @@
"""Amazon Clean Rooms Module hosting read_* functions."""

import logging
from typing import Any, Dict, Iterator, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union

import boto3

import awswrangler.pandas as pd
from awswrangler import _utils, s3
from awswrangler import _utils, exceptions, s3
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.cleanrooms._utils import wait_query

if TYPE_CHECKING:
from mypy_boto3_cleanrooms.type_defs import ProtectedQuerySQLParametersTypeDef

_logger: logging.Logger = logging.getLogger(__name__)


@@ -23,10 +26,11 @@ def _delete_after_iterate(


def read_sql_query(
sql: str,
membership_id: str,
output_bucket: str,
output_prefix: str,
sql: Optional[str] = None,
analysis_template_arn: Optional[str] = None,
membership_id: str = "",
output_bucket: str = "",
output_prefix: str = "",
keep_files: bool = True,
params: Optional[Dict[str, Any]] = None,
chunksize: Optional[Union[int, bool]] = None,
@@ -36,10 +40,16 @@ def read_sql_query(
) -> Union[Iterator[pd.DataFrame], pd.DataFrame]:
"""Execute Clean Rooms Protected SQL query and return the results as a Pandas DataFrame.
Note
----
One of `sql` or `analysis_template_arn` must be supplied, not both.
Parameters
----------
sql : str
sql : str, optional
SQL query
analysis_template_arn: str, optional
ARN of the analysis template
membership_id : str
Membership ID
output_bucket : str
@@ -49,9 +59,13 @@
keep_files : bool, optional
Whether files in S3 output bucket/prefix are retained. 'True' by default
params : Dict[str, any], optional
Dict of parameters used for constructing the SQL query. Only named parameters are supported.
(Client-side) If used in combination with the `sql` parameter, it's the Dict of parameters used
for constructing the SQL query. Only named parameters are supported.
The dict must be in the form {'name': 'value'} and the SQL query must contain
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
(Server-side) If used in combination with the `analysis_template_arn` parameter, it's the Dict of parameters
supplied with the analysis template. It must be a string to string dict in the form {'name': 'value'}.
chunksize : Union[int, bool], optional
If passed, the data is split into an iterable of DataFrames (Memory friendly).
If `True` an iterable of DataFrames is returned without guarantee of chunksize.
@@ -82,13 +96,33 @@
>>> output_bucket='output-bucket',
>>> output_prefix='output-prefix',
>>> )
>>> import awswrangler as wr
>>> df = wr.cleanrooms.read_sql_query(
>>> analysis_template_arn='arn:aws:cleanrooms:...',
>>> params={'param1': 'value1'},
>>> membership_id='membership-id',
>>> output_bucket='output-bucket',
>>> output_prefix='output-prefix',
>>> )
"""
client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)

if sql:
sql_parameters: "ProtectedQuerySQLParametersTypeDef" = {
"queryString": _process_sql_params(sql, params, engine_type="partiql")
}
elif analysis_template_arn:
sql_parameters = {"analysisTemplateArn": analysis_template_arn}
if params:
sql_parameters["parameters"] = params
else:
raise exceptions.InvalidArgumentCombination("One of `sql` or `analysis_template_arn` must be supplied")

query_id: str = client_cleanrooms.start_protected_query(
type="SQL",
membershipIdentifier=membership_id,
sqlParameters={"queryString": _process_sql_params(sql, params, engine_type="partiql")},
sqlParameters=sql_parameters,
resultConfiguration={
"outputConfiguration": {
"s3": {
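The heart of this change is the branching that builds `sqlParameters` for `start_protected_query`. Below is a minimal standalone sketch of that logic; the helper name `_resolve_sql_parameters` is hypothetical (the real code inlines this in `read_sql_query` and runs `sql` through `_process_sql_params` first):

```python
from typing import Any, Dict, Optional


def _resolve_sql_parameters(
    sql: Optional[str],
    analysis_template_arn: Optional[str],
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical standalone version of the branching added above."""
    if sql:
        # Client-side mode: in the real code, _process_sql_params() first
        # interpolates the named ":name" placeholders into the query string.
        return {"queryString": sql}
    if analysis_template_arn:
        # Server-side mode: the string-to-string params are passed through
        # untouched and bound by Clean Rooms against the template's
        # declared analysis parameters.
        sql_parameters: Dict[str, Any] = {"analysisTemplateArn": analysis_template_arn}
        if params:
            sql_parameters["parameters"] = params
        return sql_parameters
    raise ValueError("One of `sql` or `analysis_template_arn` must be supplied")
```

Note that with the if/elif ordering, `sql` silently takes precedence when a caller passes both arguments; only the all-empty case raises.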
73 changes: 73 additions & 0 deletions test_infra/stacks/cleanrooms_stack.py
@@ -127,6 +127,26 @@ def __init__(
),
)

self.custom_table = glue.Table(
self,
"Custom Table",
database=self.database,
table_name="custom",
columns=[
glue.Column(name="a", type=glue.Type(input_string="int", is_primitive=True)),
glue.Column(name="b", type=glue.Type(input_string="string", is_primitive=True)),
],
bucket=self.bucket,
s3_prefix="custom",
data_format=glue.DataFormat(
input_format=glue.InputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
output_format=glue.OutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"),
serialization_library=glue.SerializationLibrary(
"org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
),
),
)

self.users_configured_table = cleanrooms.CfnConfiguredTable(
self,
"Users Configured Table",
@@ -211,6 +231,49 @@
],
)

self.analysis_template = cleanrooms.CfnAnalysisTemplate(
self,
"AnalysisTemplate",
format="SQL",
membership_identifier=self.membership.attr_membership_identifier,
name="custom_analysis",
source=cleanrooms.CfnAnalysisTemplate.AnalysisSourceProperty(
text="SELECT a FROM custom WHERE custom.b = :param1"
),
analysis_parameters=[
cleanrooms.CfnAnalysisTemplate.AnalysisParameterProperty(
name="param1",
type="VARCHAR",
)
],
)

self.custom_configured_table = cleanrooms.CfnConfiguredTable(
self,
"Custom Configured Table",
allowed_columns=["a", "b"],
analysis_method="DIRECT_QUERY",
name="custom",
table_reference=cleanrooms.CfnConfiguredTable.TableReferenceProperty(
glue=cleanrooms.CfnConfiguredTable.GlueTableReferenceProperty(
database_name=self.database.database_name,
table_name=self.custom_table.table_name,
)
),
analysis_rules=[
cleanrooms.CfnConfiguredTable.AnalysisRuleProperty(
policy=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyProperty(
v1=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyV1Property(
custom=cleanrooms.CfnConfiguredTable.AnalysisRuleCustomProperty(
allowed_analyses=[self.analysis_template.attr_arn],
),
)
),
type="CUSTOM",
)
],
)

self.users_configured_table_association = cleanrooms.CfnConfiguredTableAssociation(
self,
"Users Configured Table Association",
@@ -229,7 +292,17 @@ def __init__(
role_arn=self.cleanrooms_service_role.role_arn,
)

self.custom_configured_table_association = cleanrooms.CfnConfiguredTableAssociation(
self,
"Custom Configured Table Association",
configured_table_identifier=self.custom_configured_table.attr_configured_table_identifier,
membership_identifier=self.membership.attr_membership_identifier,
name="custom",
role_arn=self.cleanrooms_service_role.role_arn,
)

CfnOutput(self, "CleanRoomsMembershipId", value=self.membership.attr_membership_identifier)
CfnOutput(self, "CleanRoomsAnalysisTemplateArn", value=self.analysis_template.attr_arn)
CfnOutput(self, "CleanRoomsGlueDatabaseName", value=self.database.database_name)
CfnOutput(self, "CleanRoomsS3BucketName", value=self.bucket.bucket_name)

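The CUSTOM analysis rule restricts queries against the `custom` configured table to the approved template, so the only way to query it is through the template ARN. For context, here is a sketch of the protected-query request the new code path ultimately issues against this stack; the placeholder values come from the two new CfnOutputs, and the field names inside `resultConfiguration` follow the Clean Rooms API as I understand it (treat them as assumptions, not part of this diff):

```python
import boto3

client = boto3.client("cleanrooms")

response = client.start_protected_query(
    type="SQL",
    membershipIdentifier="membership-id",  # CleanRoomsMembershipId output
    sqlParameters={
        "analysisTemplateArn": "arn:aws:cleanrooms:...",  # CleanRoomsAnalysisTemplateArn output
        # Bound server-side to :param1 in the template text.
        "parameters": {"param1": "C"},
    },
    resultConfiguration={
        "outputConfiguration": {
            "s3": {
                "bucket": "output-bucket",
                "keyPrefix": "results",
                "resultFormat": "CSV",
            }
        }
    },
)
query_id = response["protectedQuery"]["id"]
```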
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -479,6 +479,11 @@ def cleanrooms_membership_id(cloudformation_outputs):
return cloudformation_outputs["CleanRoomsMembershipId"]


@pytest.fixture(scope="session")
def cleanrooms_analysis_template_arn(cloudformation_outputs):
return cloudformation_outputs["CleanRoomsAnalysisTemplateArn"]


@pytest.fixture(scope="session")
def cleanrooms_glue_database_name(cloudformation_outputs):
return cloudformation_outputs["CleanRoomsGlueDatabaseName"]
32 changes: 31 additions & 1 deletion tests/unit/test_cleanrooms.py
@@ -39,8 +39,28 @@ def data(cleanrooms_s3_bucket_name: str, cleanrooms_glue_database_name: str) ->
mode="overwrite",
)

df_custom = pd.DataFrame(
{
"a": list(range(1, 9)),
"b": ["A", "A", "B", "C", "C", "C", "D", "E"],
}
)
wr.s3.to_parquet(
df_custom,
f"s3://{cleanrooms_s3_bucket_name}/custom/",
dataset=True,
database=cleanrooms_glue_database_name,
table="custom",
mode="overwrite",
)

def test_read_sql_query(data: None, cleanrooms_membership_id: str, cleanrooms_s3_bucket_name: str):

def test_read_sql_query(
data: None,
cleanrooms_membership_id: str,
cleanrooms_analysis_template_arn: str,
cleanrooms_s3_bucket_name: str,
):
sql = """SELECT city, AVG(p.sale_value)
FROM users u
INNER JOIN purchases p ON u.user_id = p.user_id
@@ -71,3 +91,13 @@ def test_read_sql_query(data: None, cleanrooms_membership_id: str, cleanrooms_s3
keep_files=False,
)
assert df.shape == (2, 3)

df = wr.cleanrooms.read_sql_query(
analysis_template_arn=cleanrooms_analysis_template_arn,
params={"param1": "C"},
membership_id=cleanrooms_membership_id,
output_bucket=cleanrooms_s3_bucket_name,
output_prefix="results",
keep_files=False,
)
assert df.shape == (3, 1)
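The expected shape follows directly from the fixture data and the template text: `param1='C'` matches three of the eight rows in `df_custom`, and the template projects only column `a`. A quick local sanity check of that arithmetic:

```python
import pandas as pd

# "SELECT a FROM custom WHERE custom.b = :param1" with param1="C" keeps the
# three rows where b == "C" and projects the single column `a`.
df_custom = pd.DataFrame(
    {
        "a": list(range(1, 9)),
        "b": ["A", "A", "B", "C", "C", "C", "D", "E"],
    }
)
assert df_custom.loc[df_custom["b"] == "C", ["a"]].shape == (3, 1)
```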
