serialize asset/dataset timetable conditions in OpenLineage info also for older supported Airflow 2 versions #43434

Merged

@@ -47,11 +47,14 @@

_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
_IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
_IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")

# dataset is renamed to asset since Airflow 3.0
from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
from airflow.datasets import Dataset as Asset

if _IS_AIRFLOW_2_8_OR_HIGHER:
from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails

if _IS_AIRFLOW_2_9_OR_HIGHER:
from airflow.datasets import (
DatasetAll as AssetAll,
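A note on the version gates above: the inner Version(...).base_version round-trip is what lets pre-release builds pass the checks, because under PEP 440 a dev or rc build sorts below its final release. A minimal sketch, assuming only the packaging library:

from packaging.version import Version

# Under PEP 440, a dev build compares below its final release...
assert Version("2.10.0.dev0") < Version("2.10.0")
# ...but base_version strips the pre-release suffix, so the gate still opens.
assert Version(Version("2.10.0.dev0").base_version) >= Version("2.10.0")
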
@@ -198,6 +198,8 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
except ImportError:
return None

if not hasattr(get_hook_lineage_collector(), "has_collected"):
return None
if not get_hook_lineage_collector().has_collected:
return None

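The hasattr probe above guards against older hook lineage collectors that predate the has_collected property. A hypothetical sketch of the same probe-before-use pattern (_LegacyCollector and _safe_has_collected are illustrative names, not part of the provider):

# Illustrative only: a collector from an older compat layer that never
# grew the has_collected property.
class _LegacyCollector:
    collected_inputs: list = []

def _safe_has_collected(collector) -> bool:
    # A missing attribute is treated the same as "nothing collected yet".
    return getattr(collector, "has_collected", False)

assert _safe_has_collected(_LegacyCollector()) is False
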
24 changes: 23 additions & 1 deletion providers/src/airflow/providers/openlineage/utils/utils.py
@@ -262,9 +262,31 @@ class DagInfo(InfoJsonEncodable):
"start_date",
"tags",
]
casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
casts = {"timetable": lambda dag: DagInfo.serialize_timetable(dag)}
renames = {"_dag_id": "dag_id"}

@classmethod
def serialize_timetable(cls, dag: DAG) -> dict[str, Any]:
serialized = dag.timetable.serialize()
if serialized != {} and serialized is not None:
return serialized
if (
hasattr(dag, "dataset_triggers")
and isinstance(dag.dataset_triggers, list)
and len(dag.dataset_triggers)
):
triggers = dag.dataset_triggers
return {
"dataset_condition": {
"__type": "dataset_all",
"objects": [
{"__type": "dataset", "uri": trigger.uri, "extra": trigger.extra}
for trigger in triggers
],
}
}
return {}


class DagRunInfo(InfoJsonEncodable):
"""Defines encoding DagRun object to JSON."""
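The net effect of the fallback above: a DAG on an older Airflow 2 version whose timetable serializes to an empty dict still reports its dataset triggers, in a shape matching the newer serialization. A rough sketch of the behavior, with SimpleNamespace standing in for a DAG and its Dataset triggers (uri plus extra):

from types import SimpleNamespace

# Stand-in DAG: the timetable serializes to {}, but dataset triggers exist.
dag = SimpleNamespace(
    timetable=SimpleNamespace(serialize=lambda: {}),
    dataset_triggers=[SimpleNamespace(uri="s3://bucket/key", extra=None)],
)

# DagInfo.serialize_timetable(dag) would then fall through to:
# {"dataset_condition": {"__type": "dataset_all",
#   "objects": [{"__type": "dataset", "uri": "s3://bucket/key", "extra": None}]}}
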
125 changes: 124 additions & 1 deletion providers/tests/openlineage/plugins/test_utils.py
@@ -29,8 +29,10 @@
from pkg_resources import parse_version

from airflow.models import DAG as AIRFLOW_DAG, DagModel
from airflow.providers.common.compat.assets import Asset
from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
from airflow.providers.openlineage.utils.utils import (
DagInfo,
InfoJsonEncodable,
OpenLineageRedactor,
_get_all_packages_installed,
@@ -40,11 +42,18 @@
get_fully_qualified_class_name,
is_operator_disabled,
)
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils import timezone
from airflow.utils.log.secrets_masker import _secrets_masker
from airflow.utils.state import State

from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, BashOperator
from tests_common.test_utils.compat import (
AIRFLOW_V_2_8_PLUS,
AIRFLOW_V_2_9_PLUS,
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
BashOperator,
)

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
@@ -322,3 +331,117 @@ def test_does_not_include_full_task_info(mock_include_full_task_info):
MagicMock(),
)["airflow"].task
)


@pytest.mark.db_test
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions")
def test_serialize_timetable():
from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny
from airflow.timetables.simple import AssetTriggeredTimetable

asset = AssetAny(
Asset("2"),
AssetAlias("example-alias"),
Asset("3"),
AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
)
dag = MagicMock()
dag.timetable = AssetTriggeredTimetable(asset)
dag_info = DagInfo(dag)

assert dag_info.timetable == {
"asset_condition": {
"__type": DagAttributeTypes.ASSET_ANY,
"objects": [
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
{
"__type": DagAttributeTypes.ASSET_ALL,
"objects": [
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
],
},
],
}
}


@pytest.mark.db_test
@pytest.mark.skipif(
not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS,
reason="This test checks serialization only in 2.10 conditions",
)
def test_serialize_timetable_2_10():
from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny
from airflow.timetables.simple import DatasetTriggeredTimetable

asset = AssetAny(
Asset("2"),
AssetAlias("example-alias"),
Asset("3"),
AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
)

dag = MagicMock()
dag.timetable = DatasetTriggeredTimetable(asset)
dag_info = DagInfo(dag)

assert dag_info.timetable == {
"dataset_condition": {
"__type": DagAttributeTypes.DATASET_ANY,
"objects": [
{"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "2"},
{"__type": DagAttributeTypes.DATASET_ANY, "objects": []},
{"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "3"},
{
"__type": DagAttributeTypes.DATASET_ALL,
"objects": [
{"__type": DagAttributeTypes.DATASET_ANY, "objects": []},
{"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "4"},
],
},
],
}
}


@pytest.mark.skipif(
not AIRFLOW_V_2_9_PLUS or AIRFLOW_V_2_10_PLUS,
reason="This test checks serialization only in 2.9 conditions",
)
def test_serialize_timetable_2_9():
dag = MagicMock()
dag.timetable.serialize.return_value = {}
dag.dataset_triggers = [Asset("a"), Asset("b")]
dag_info = DagInfo(dag)
assert dag_info.timetable == {
"dataset_condition": {
"__type": "dataset_all",
"objects": [
{"__type": "dataset", "extra": None, "uri": "a"},
{"__type": "dataset", "extra": None, "uri": "b"},
],
}
}


@pytest.mark.skipif(
not AIRFLOW_V_2_8_PLUS or AIRFLOW_V_2_9_PLUS,
reason="This test checks serialization only in 2.8 conditions",
)
def test_serialize_timetable_2_8():
dag = MagicMock()
dag.timetable.serialize.return_value = {}
dag.dataset_triggers = [Asset("a"), Asset("b")]
dag_info = DagInfo(dag)
assert dag_info.timetable == {
"dataset_condition": {
"__type": "dataset_all",
"objects": [
{"__type": "dataset", "extra": None, "uri": "a"},
{"__type": "dataset", "extra": None, "uri": "b"},
],
}
}