Commit 9af2636

fix: remove part of openlineage extraction from S3ToRedshiftOperator (apache#41631)

Signed-off-by: Kacper Muda <[email protected]>
kacpermuda authored Aug 21, 2024
1 parent 2727d5d commit 9af2636
Showing 2 changed files with 75 additions and 112 deletions.
43 changes: 8 additions & 35 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -59,7 +59,8 @@ class S3ToRedshiftOperator(BaseOperator):
         - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
                 You can specify this argument if you want to use a different
                 CA cert bundle than the one used by botocore.
-    :param column_list: list of column names to load
+    :param column_list: list of column names to load source data fields into specific target columns
+        https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-column-mapping.html#copy-column-list
     :param copy_options: reference to a list of COPY options
     :param method: Action to be performed on execution. Available ``APPEND``, ``UPSERT`` and ``REPLACE``.
     :param upsert_keys: List of fields to use as key on upsert action
@@ -204,18 +205,13 @@ def execute(self, context: Context) -> None:
 
     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement on_complete as we will query destination table."""
-        from pathlib import Path
-
         from airflow.providers.amazon.aws.utils.openlineage import (
             get_facets_from_redshift_table,
-            get_identity_column_lineage_facet,
         )
         from airflow.providers.common.compat.openlineage.facet import (
             Dataset,
-            Identifier,
             LifecycleStateChange,
             LifecycleStateChangeDatasetFacet,
-            SymlinksDatasetFacet,
         )
         from airflow.providers.openlineage.extractors import OperatorLineage
 
@@ -235,36 +231,8 @@ def get_openlineage_facets_on_complete(self, task_instance):
         database = redshift_sql_hook.conn.schema
         authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority
         output_dataset_facets = get_facets_from_redshift_table(
-            redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema
+            redshift_sql_hook, self.table, {}, self.schema
         )
-
-        input_dataset_facets = {}
-        if not self.column_list:
-            # If column_list is not specified, then we know that input file matches columns of output table.
-            input_dataset_facets["schema"] = output_dataset_facets["schema"]
-
-        dataset_name = self.s3_key
-        if "*" in dataset_name:
-            # If wildcard ("*") is used in s3 path, we want the name of dataset to be directory name,
-            # but we create a symlink to the full object path with wildcard.
-            input_dataset_facets["symlink"] = SymlinksDatasetFacet(
-                identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")]
-            )
-            dataset_name = Path(dataset_name).parent.as_posix()
-            if dataset_name == ".":
-                # blob path does not have leading slash, but we need root dataset name to be "/"
-                dataset_name = "/"
-
-        input_dataset = Dataset(
-            namespace=f"s3://{self.s3_bucket}",
-            name=dataset_name,
-            facets=input_dataset_facets,
-        )
-
-        output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
-            field_names=[field.name for field in output_dataset_facets["schema"].fields],
-            input_datasets=[input_dataset],
-        )
 
         if self.method == "REPLACE":
             output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet(
@@ -277,4 +245,9 @@ def get_openlineage_facets_on_complete(self, task_instance):
             facets=output_dataset_facets,
         )
 
+        input_dataset = Dataset(
+            namespace=f"s3://{self.s3_bucket}",
+            name=self.s3_key,
+        )
+
         return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
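
For context, a minimal sketch of the lineage shape the trimmed-down method now emits, assuming an operator configured with s3_bucket="bucket", s3_key="key", schema="schema", table="table" and method="REPLACE"; the output facet values are illustrative stand-ins for whatever get_facets_from_redshift_table returns:

from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    LifecycleStateChange,
    LifecycleStateChangeDatasetFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

# Input side: just the raw S3 key, with no schema, symlink or columnLineage
# facets attached anymore.
input_dataset = Dataset(namespace="s3://bucket", name="key")

# Output side: facets built by get_facets_from_redshift_table(), plus the
# lifecycleStateChange facet that is only added when method == "REPLACE".
output_dataset = Dataset(
    namespace="redshift://cluster.region:5439",
    name="database.schema.table",
    facets={
        "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
            lifecycleStateChange=LifecycleStateChange.OVERWRITE
        ),
    },
)

lineage = OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])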
144 changes: 67 additions & 77 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -26,7 +26,13 @@
 from airflow.exceptions import AirflowException
 from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
-from airflow.providers.common.compat.openlineage.facet import LifecycleStateChange
+from airflow.providers.common.compat.openlineage.facet import (
+    DocumentationDatasetFacet,
+    LifecycleStateChange,
+    LifecycleStateChangeDatasetFacet,
+    SchemaDatasetFacet,
+    SchemaDatasetFacetFields,
+)
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces


@@ -502,8 +508,9 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con
     @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
     def test_get_openlineage_facets_on_complete_default(
-        self, mock_run, mock_session, mock_connection, mock_hook
+        self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
     ):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -515,6 +522,11 @@ def test_get_openlineage_facets_on_complete_default(
         mock_connection.return_value = mock.MagicMock(
             schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
         )
+        mock_facets = {
+            "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
+            "documentation": DocumentationDatasetFacet(description="mock_description"),
+        }
+        mock_get_facets.return_value = mock_facets
 
         schema = "schema"
         table = "table"
@@ -531,33 +543,30 @@ def test_get_openlineage_facets_on_complete_default(
             redshift_conn_id="redshift_conn_id",
             aws_conn_id="aws_conn_id",
             task_id="task_id",
-            dag=None,
         )
         op.execute(None)
 
         lineage = op.get_openlineage_facets_on_complete(None)
-        # Hook called two times - on operator execution, and on querying data in redshift to fetch schema
-        assert mock_run.call_count == 2
+        # Hook called only one time - on operator execution - we mocked querying to fetch schema
+        assert mock_run.call_count == 1
 
         assert len(lineage.inputs) == 1
         assert len(lineage.outputs) == 1
         assert lineage.inputs[0].name == s3_key
         assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
         assert lineage.outputs[0].name == f"database.{schema}.{table}"
         assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"
 
-        assert lineage.outputs[0].facets.get("schema") is not None
-        assert lineage.outputs[0].facets.get("columnLineage") is not None
-
-        assert lineage.inputs[0].facets.get("schema") is not None
-        # As method was not overwrite, there should be no lifecycleStateChange facet
         assert "lifecycleStateChange" not in lineage.outputs[0].facets
+        assert lineage.outputs[0].facets == mock_facets
+        assert lineage.inputs[0].facets == {}
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
     def test_get_openlineage_facets_on_complete_replace(
-        self, mock_run, mock_session, mock_connection, mock_hook
+        self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
     ):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -569,6 +578,11 @@ def test_get_openlineage_facets_on_complete_replace(
         mock_connection.return_value = mock.MagicMock(
             schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
         )
+        mock_facets = {
+            "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
+            "documentation": DocumentationDatasetFacet(description="mock_description"),
+        }
+        mock_get_facets.return_value = mock_facets
 
         schema = "schema"
         table = "table"
@@ -586,59 +600,25 @@ def test_get_openlineage_facets_on_complete_replace(
             redshift_conn_id="redshift_conn_id",
             aws_conn_id="aws_conn_id",
             task_id="task_id",
-            dag=None,
         )
         op.execute(None)
 
         lineage = op.get_openlineage_facets_on_complete(None)
 
-        assert (
-            lineage.outputs[0].facets["lifecycleStateChange"].lifecycleStateChange
-            == LifecycleStateChange.OVERWRITE
-        )
-
-    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
-    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
-    @mock.patch("boto3.session.Session")
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
-    def test_get_openlineage_facets_on_complete_column_list(
-        self, mock_run, mock_session, mock_connection, mock_hook
-    ):
-        access_key = "aws_access_key_id"
-        secret_key = "aws_secret_access_key"
-        mock_session.return_value = Session(access_key, secret_key)
-        mock_session.return_value.access_key = access_key
-        mock_session.return_value.secret_key = secret_key
-        mock_session.return_value.token = None
-
-        mock_connection.return_value = mock.MagicMock(
-            schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
-        )
-
-        schema = "schema"
-        table = "table"
-        s3_bucket = "bucket"
-        s3_key = "key"
-        copy_options = ""
-
-        op = S3ToRedshiftOperator(
-            schema=schema,
-            table=table,
-            s3_bucket=s3_bucket,
-            s3_key=s3_key,
-            copy_options=copy_options,
-            column_list=["column1", "column2"],
-            redshift_conn_id="redshift_conn_id",
-            aws_conn_id="aws_conn_id",
-            task_id="task_id",
-            dag=None,
-        )
-        op.execute(None)
-
-        lineage = op.get_openlineage_facets_on_complete(None)
         assert len(lineage.inputs) == 1
         assert len(lineage.outputs) == 1
         assert lineage.inputs[0].name == s3_key
         assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
         assert lineage.outputs[0].name == f"database.{schema}.{table}"
         assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"
 
-        assert lineage.outputs[0].facets.get("schema") is not None
-        assert lineage.inputs[0].facets.get("schema") is None
+        assert lineage.outputs[0].facets == {
+            **mock_facets,
+            "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                lifecycleStateChange=LifecycleStateChange.OVERWRITE
+            ),
+        }
+        assert lineage.inputs[0].facets == {}
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
Expand All @@ -648,8 +628,9 @@ def test_get_openlineage_facets_on_complete_column_list(
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
new_callable=mock.PropertyMock,
)
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_using_redshift_data_api(
self, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook
):
"""
Using the Redshift Data API instead of the SQL-based connection
@@ -666,6 +647,11 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
         mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
 
         mock_rs_region.return_value = "region"
+        mock_facets = {
+            "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
+            "documentation": DocumentationDatasetFacet(description="mock_description"),
+        }
+        mock_get_facets.return_value = mock_facets
 
         schema = "schema"
         table = "table"
@@ -689,7 +675,7 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
             redshift_conn_id="redshift_conn_id",
             aws_conn_id="aws_conn_id",
             task_id="task_id",
-            dag=None,
+            method="REPLACE",
             redshift_data_api_kwargs=dict(
                 database=database,
                 cluster_identifier=cluster_identifier,
@@ -705,15 +691,17 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
         assert len(lineage.inputs) == 1
         assert len(lineage.outputs) == 1
         assert lineage.inputs[0].name == s3_key
         assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
         assert lineage.outputs[0].name == f"database.{schema}.{table}"
         assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"
 
-        assert lineage.outputs[0].facets.get("schema") is not None
-        assert lineage.outputs[0].facets.get("columnLineage") is not None
-
-        assert lineage.inputs[0].facets.get("schema") is not None
-        # As method was not overwrite, there should be no lifecycleStateChange facet
-        assert "lifecycleStateChange" not in lineage.outputs[0].facets
+        assert lineage.outputs[0].facets == {
+            **mock_facets,
+            "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                lifecycleStateChange=LifecycleStateChange.OVERWRITE
+            ),
+        }
+        assert lineage.inputs[0].facets == {}
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@@ -724,8 +712,9 @@
         "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
         new_callable=mock.PropertyMock,
     )
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
     def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
-        self, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
+        self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
     ):
         """
         Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage.
@@ -745,6 +734,11 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
         mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
 
         mock_rs_region.return_value = "region"
+        mock_facets = {
+            "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
+            "documentation": DocumentationDatasetFacet(description="mock_description"),
+        }
+        mock_get_facets.return_value = mock_facets
 
         schema = "schema"
         table = "table"
@@ -794,13 +788,9 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
         op_rs_sql.execute(None)
         rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None)
 
-        assert rs_sql_lineage.inputs == rs_data_lineage.inputs
-        assert len(rs_sql_lineage.inputs) == 1
-        assert len(rs_sql_lineage.outputs) == 1
-        assert len(rs_data_lineage.outputs) == 1
-        assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"]
-        assert (
-            rs_sql_lineage.outputs[0].facets["columnLineage"]
-            == rs_data_lineage.outputs[0].facets["columnLineage"]
-        )
-        assert rs_sql_lineage.outputs[0].name == rs_data_lineage.outputs[0].name
-        assert rs_sql_lineage.outputs[0].namespace == rs_data_lineage.outputs[0].namespace
+        assert rs_sql_lineage.inputs == rs_data_lineage.inputs
+        assert rs_sql_lineage.outputs == rs_data_lineage.outputs
+        assert rs_sql_lineage.job_facets == rs_data_lineage.job_facets
+        assert rs_sql_lineage.run_facets == rs_data_lineage.run_facets
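
A condensed sketch of the mocking pattern the updated tests share: get_facets_from_redshift_table is patched out, so building lineage never queries Redshift for the table schema, and the fake facets are asserted back on the output dataset. The test name and standalone layout below are illustrative; the patch target and facet classes are the ones used in the diff above:

from unittest import mock

from airflow.providers.common.compat.openlineage.facet import (
    DocumentationDatasetFacet,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
)

MOCK_FACETS = {
    "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
    "documentation": DocumentationDatasetFacet(description="mock_description"),
}

@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_output_facets_come_from_helper(mock_get_facets):
    # With the helper patched, only the COPY statement itself hits the hook.
    mock_get_facets.return_value = MOCK_FACETS
    # Build and execute an S3ToRedshiftOperator as in the tests above, then:
    # lineage = op.get_openlineage_facets_on_complete(None)
    # assert lineage.outputs[0].facets == MOCK_FACETS
    # assert lineage.inputs[0].facets == {}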
