diff --git a/awswrangler/athena/_spark.py b/awswrangler/athena/_spark.py
index 40cdecf2e..10ef03097 100644
--- a/awswrangler/athena/_spark.py
+++ b/awswrangler/athena/_spark.py
@@ -91,6 +91,7 @@ def create_spark_session(
     max_concurrent_dpus: int = 5,
     default_executor_dpu_size: int = 1,
     additional_configs: Optional[Dict[str, Any]] = None,
+    spark_properties: Optional[Dict[str, Any]] = None,
     idle_timeout: int = 15,
     boto3_session: Optional[boto3.Session] = None,
 ) -> str:
@@ -110,6 +111,9 @@ def create_spark_session(
         The default number of DPUs to use for executors. The default is 1.
     additional_configs : Dict[str, Any], optional
         Contains additional engine parameter mappings in the form of key-value pairs.
+    spark_properties : Dict[str, Any], optional
+        Contains SparkProperties in the form of key-value pairs. Specifies custom JAR files and Spark properties
+        for use cases like cluster encryption, table formats, and general Spark tuning.
     idle_timeout : int, optional
         The idle timeout in minutes for the session. The default is 15.
     boto3_session : boto3.Session(), optional
@@ -134,6 +138,8 @@ def create_spark_session(
     }
     if additional_configs:
         engine_configuration["AdditionalConfigs"] = additional_configs
+    if spark_properties:
+        engine_configuration["SparkProperties"] = spark_properties
     response = client_athena.start_session(
         WorkGroup=workgroup,
         EngineConfiguration=engine_configuration,
@@ -157,6 +163,7 @@ def run_spark_calculation(
     max_concurrent_dpus: int = 5,
     default_executor_dpu_size: int = 1,
     additional_configs: Optional[Dict[str, Any]] = None,
+    spark_properties: Optional[Dict[str, Any]] = None,
     idle_timeout: int = 15,
     boto3_session: Optional[boto3.Session] = None,
 ) -> Dict[str, Any]:
@@ -180,6 +187,9 @@ def run_spark_calculation(
         The default number of DPUs to use for executors. The default is 1.
     additional_configs : Dict[str, Any], optional
         Contains additional engine parameter mappings in the form of key-value pairs.
+    spark_properties : Dict[str, Any], optional
+        Contains SparkProperties in the form of key-value pairs. Specifies custom JAR files and Spark properties
+        for use cases like cluster encryption, table formats, and general Spark tuning.
     idle_timeout : int, optional
         The idle timeout in minutes for the session. The default is 15.
     boto3_session : boto3.Session(), optional
@@ -208,6 +218,7 @@ def run_spark_calculation(
         max_concurrent_dpus=max_concurrent_dpus,
         default_executor_dpu_size=default_executor_dpu_size,
         additional_configs=additional_configs,
+        spark_properties=spark_properties,
         idle_timeout=idle_timeout,
         boto3_session=boto3_session,
     )
diff --git a/tests/unit/test_athena_spark.py b/tests/unit/test_athena_spark.py
index 5de8b9200..e1d94178d 100644
--- a/tests/unit/test_athena_spark.py
+++ b/tests/unit/test_athena_spark.py
@@ -47,3 +47,28 @@ def test_athena_spark_calculation(code, path, workgroup_spark):
     )
 
     assert result["Status"]["State"] == "COMPLETED"
+
+
+@pytest.mark.parametrize(
+    "code",
+    [
+        """
+output_path = "$PATH"
+
+data = spark.range(0, 5)
+data.write.format("delta").save(output_path)
+    """,
+    ],
+)
+def test_athena_spark_calculation_with_spark_properties(code, path, workgroup_spark):
+    code = code.replace("$PATH", path)
+
+    result = wr.athena.run_spark_calculation(
+        code=code,
+        workgroup=workgroup_spark,
+        spark_properties={
+            "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+            "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
+        },
+    )
+    assert result["Status"]["State"] == "COMPLETED"
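
For reference, a minimal usage sketch of the new parameter with create_spark_session, assuming a workgroup already configured for the Athena Spark engine; the workgroup name below is a placeholder, and the Delta Lake property values are taken from the test above:

import awswrangler as wr

# Start an Athena for Apache Spark session with Delta Lake enabled through
# Spark properties; these key-value pairs are forwarded as
# EngineConfiguration["SparkProperties"] when the session is started.
session_id = wr.athena.create_spark_session(
    workgroup="my-spark-workgroup",  # placeholder: any Athena Spark-enabled workgroup
    spark_properties={
        "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
        "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
    },
)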