serialize asset/dataset timetable conditions in OpenLineage info also for Airflow 2

Signed-off-by: Maciej Obuchowski <[email protected]>
mobuchowski committed Oct 27, 2024
1 parent e9192f5 commit 7537e83
Showing 3 changed files with 139 additions and 2 deletions.
@@ -198,6 +198,8 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
         except ImportError:
             return None
 
+        if not hasattr(get_hook_lineage_collector(), "has_collected"):
+            return None
         if not get_hook_lineage_collector().has_collected:
             return None

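Note on the new guard above: has_collected is checked with hasattr first, presumably because older hook-level lineage collectors predate the attribute, and accessing it directly would raise AttributeError on those versions. A minimal, self-contained sketch of the pattern, with stand-in collector classes rather than the real Airflow API:

class _OldCollector:  # stand-in: predates the has_collected attribute
    pass

class _NewCollector:  # stand-in: newer collector exposing has_collected
    has_collected = True

def hook_lineage_or_none(collector):
    # Mirrors the added guard: bail out unless the attribute exists and is truthy.
    if not hasattr(collector, "has_collected"):
        return None
    if not collector.has_collected:
        return None
    return "collected lineage"  # placeholder for the real return value

assert hook_lineage_or_none(_OldCollector()) is None
assert hook_lineage_or_none(_NewCollector()) == "collected lineage"
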
providers/src/airflow/providers/openlineage/utils/utils.py: 20 changes (19 additions & 1 deletion)
@@ -262,9 +262,27 @@ class DagInfo(InfoJsonEncodable):
"start_date",
"tags",
]
casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
casts = {"timetable": lambda dag: DagInfo.serialize_timetable(dag)}
renames = {"_dag_id": "dag_id"}

@classmethod
def serialize_timetable(cls, dag):
serialized = dag.timetable.serialize()
if serialized != {} and serialized is not None:
return serialized
if hasattr(dag, "dataset_triggers"):
triggers = dag.dataset_triggers
return {
"dataset_condition": {
"__type": "dataset_all",
"objects": [
{"__type": "dataset", "uri": trigger.uri, "extra": trigger.extra}
for trigger in triggers
],
}
}
return {}


class DagRunInfo(InfoJsonEncodable):
"""Defines encoding DagRun object to JSON."""
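The fallback branch in serialize_timetable above is what extends this to Airflow 2 versions whose dataset-triggered timetables serialize to an empty dict: when serialize() yields nothing, the condition is rebuilt from dag.dataset_triggers. A minimal sketch of that path, using SimpleNamespace stand-ins instead of real Airflow objects (the s3:// URI is illustrative only):

from types import SimpleNamespace

class _LegacyTimetable:
    # Stand-in: serializes to an empty dict, as on older Airflow 2 versions.
    def serialize(self):
        return {}

trigger = SimpleNamespace(uri="s3://bucket/table", extra={})
dag = SimpleNamespace(timetable=_LegacyTimetable(), dataset_triggers=[trigger])

# Mirrors the fallback branch of DagInfo.serialize_timetable:
serialized = dag.timetable.serialize()
if not serialized and hasattr(dag, "dataset_triggers"):
    serialized = {
        "dataset_condition": {
            "__type": "dataset_all",
            "objects": [
                {"__type": "dataset", "uri": t.uri, "extra": t.extra}
                for t in dag.dataset_triggers
            ],
        }
    }
print(serialized)  # {'dataset_condition': {'__type': 'dataset_all', 'objects': [...]}}
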
providers/tests/openlineage/plugins/test_utils.py: 119 changes (118 additions & 1 deletion)
@@ -29,8 +29,10 @@
 from pkg_resources import parse_version
 
 from airflow.models import DAG as AIRFLOW_DAG, DagModel
+from airflow.providers.common.compat.assets import Asset, AssetAlias, AssetAll, AssetAny
 from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
 from airflow.providers.openlineage.utils.utils import (
+    DagInfo,
     InfoJsonEncodable,
     OpenLineageRedactor,
     _get_all_packages_installed,
@@ -40,11 +42,18 @@
     get_fully_qualified_class_name,
     is_operator_disabled,
 )
+from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils import timezone
 from airflow.utils.log.secrets_masker import _secrets_masker
 from airflow.utils.state import State
 
-from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, BashOperator
+from tests_common.test_utils.compat import (
+    AIRFLOW_V_2_8_PLUS,
+    AIRFLOW_V_2_9_PLUS,
+    AIRFLOW_V_2_10_PLUS,
+    AIRFLOW_V_3_0_PLUS,
+    BashOperator,
+)
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
@@ -320,3 +329,111 @@ def test_does_not_include_full_task_info(mock_include_full_task_info):
             MagicMock(),
         )["airflow"].task
     )
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions")
+def test_serialize_timetable():
+    from airflow.timetables.simple import AssetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+    dag = MagicMock()
+    dag.timetable = AssetTriggeredTimetable(asset)
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "asset_condition": {
+            "__type": DagAttributeTypes.ASSET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
+                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS,
+    reason="This test checks serialization only in 2.10 conditions",
+)
+def test_serialize_timetable_2_10():
+    from airflow.timetables.simple import DatasetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+
+    dag = MagicMock()
+    dag.timetable = DatasetTriggeredTimetable(asset)
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
+                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_9_PLUS or AIRFLOW_V_2_10_PLUS,
+    reason="This test checks serialization only in 2.9 conditions",
+)
+def test_serialize_timetable_2_9():
+    dag = MagicMock()
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ALL,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "a"},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "b"},
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_8_PLUS or AIRFLOW_V_2_9_PLUS,
+    reason="This test checks serialization only in 2.8 conditions",
+)
+def test_serialize_timetable_2_8():
+    dag = MagicMock()
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ALL,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "a"},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "b"},
+            ],
+        }
+    }

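The four tests above are made mutually exclusive by their skipif guards, so any given Airflow installation runs at most one of them. A hypothetical way to run just this group from a checkout (path as shown in the diff header):

import pytest

# Select only the timetable-serialization tests by keyword match.
pytest.main([
    "providers/tests/openlineage/plugins/test_utils.py",
    "-k", "serialize_timetable",
])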