diff --git a/providers/src/airflow/providers/openlineage/extractors/manager.py b/providers/src/airflow/providers/openlineage/extractors/manager.py
index f6d572bae5313..be824335718b1 100644
--- a/providers/src/airflow/providers/openlineage/extractors/manager.py
+++ b/providers/src/airflow/providers/openlineage/extractors/manager.py
@@ -198,6 +198,10 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
         except ImportError:
             return None
 
+        # Hook lineage collectors from older provider versions may not expose
+        # `has_collected`; treat them as having collected nothing.
+        if not hasattr(get_hook_lineage_collector(), "has_collected"):
+            return None
         if not get_hook_lineage_collector().has_collected:
             return None
 
diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py
index a00552eed251f..f41fa5e82f2c7 100644
--- a/providers/src/airflow/providers/openlineage/utils/utils.py
+++ b/providers/src/airflow/providers/openlineage/utils/utils.py
@@ -262,9 +262,31 @@ class DagInfo(InfoJsonEncodable):
         "start_date",
         "tags",
     ]
-    casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
+    casts = {"timetable": lambda dag: DagInfo.serialize_timetable(dag)}
     renames = {"_dag_id": "dag_id"}
 
+    @classmethod
+    def serialize_timetable(cls, dag):
+        # Newer Airflow versions serialize dataset/asset conditions through
+        # the timetable itself; older ones may return None or an empty dict.
+        serialized = dag.timetable.serialize()
+        if serialized != {} and serialized is not None:
+            return serialized
+        # Fallback for older Airflow versions, where dataset-triggered DAGs
+        # expose a plain `dataset_triggers` list instead.
+        if hasattr(dag, "dataset_triggers"):
+            triggers = dag.dataset_triggers
+            return {
+                "dataset_condition": {
+                    "__type": "dataset_all",
+                    "objects": [
+                        {"__type": "dataset", "uri": trigger.uri, "extra": trigger.extra}
+                        for trigger in triggers
+                    ],
+                }
+            }
+        return {}
+
 
 class DagRunInfo(InfoJsonEncodable):
     """Defines encoding DagRun object to JSON."""
diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py
index 12f7ed32a8259..5e4e9b8c0cb5b 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -29,8 +29,10 @@
 from pkg_resources import parse_version
 
 from airflow.models import DAG as AIRFLOW_DAG, DagModel
+from airflow.providers.common.compat.assets import Asset, AssetAlias, AssetAll, AssetAny
 from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
 from airflow.providers.openlineage.utils.utils import (
+    DagInfo,
     InfoJsonEncodable,
     OpenLineageRedactor,
     _get_all_packages_installed,
@@ -40,11 +42,18 @@
     get_fully_qualified_class_name,
     is_operator_disabled,
 )
+from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils import timezone
 from airflow.utils.log.secrets_masker import _secrets_masker
 from airflow.utils.state import State
 
-from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, BashOperator
+from tests_common.test_utils.compat import (
+    AIRFLOW_V_2_8_PLUS,
+    AIRFLOW_V_2_9_PLUS,
+    AIRFLOW_V_2_10_PLUS,
+    AIRFLOW_V_3_0_PLUS,
+    BashOperator,
+)
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
@@ -320,3 +329,118 @@ def test_does_not_include_full_task_info(mock_include_full_task_info):
             MagicMock(),
         )["airflow"].task
     )
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions")
+def test_serialize_timetable():
+    from airflow.timetables.simple import AssetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+    dag = MagicMock()
+    dag.timetable = AssetTriggeredTimetable(asset)
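+    # Unresolved AssetAlias entries serialize to empty asset_any objects, so
+    # "this-should-not-be-seen" never appears in the condition asserted below.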
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "asset_condition": {
+            "__type": DagAttributeTypes.ASSET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
+                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS,
+    reason="This test checks serialization only in 2.10 conditions",
+)
+def test_serialize_timetable_2_10():
+    from airflow.timetables.simple import DatasetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+
+    dag = MagicMock()
+    dag.timetable = DatasetTriggeredTimetable(asset)
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
+                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_9_PLUS or AIRFLOW_V_2_10_PLUS,
+    reason="This test checks serialization only in 2.9 conditions",
+)
+def test_serialize_timetable_2_9():
+    dag = MagicMock()
+    # Stub serialize() to return {} so DagInfo.serialize_timetable takes the
+    # `dataset_triggers` fallback path rather than returning the mock object.
+    dag.timetable.serialize.return_value = {}
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ALL,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "a"},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "b"},
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_8_PLUS or AIRFLOW_V_2_9_PLUS,
+    reason="This test checks serialization only in 2.8 conditions",
+)
+def test_serialize_timetable_2_8():
+    dag = MagicMock()
+    # Same stub as above: force the fallback path on the 2.8 code path.
+    dag.timetable.serialize.return_value = {}
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.ASSET_ALL,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "a"},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "b"},
+            ],
+        }
+    }