From 598bc69352bdf9f951bbe7f97a4b854f3286c663 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Fri, 8 Nov 2024 09:49:01 -0500 Subject: [PATCH] AIP-82 Save references between assets and triggers --- airflow/assets/__init__.py | 36 +++++++++++++++-- airflow/dag_processing/collection.py | 45 +++++++++++++++++++++ airflow/models/dag.py | 1 + task_sdk/src/airflow/sdk/definitions/dag.py | 6 +-- 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index 58256929948a8..5e6a3952fa716 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -36,6 +36,8 @@ from sqlalchemy.orm.session import Session + from airflow.triggers.base import BaseTrigger + from airflow.configuration import conf @@ -281,20 +283,43 @@ class Asset(os.PathLike, BaseAsset): uri: str group: str extra: dict[str, Any] + watchers: list[BaseTrigger] = [] asset_type: ClassVar[str] = "" __version__: ClassVar[int] = 1 @overload - def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + name: str, + uri: str, + *, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """Canonical; both name and uri are provided.""" @overload - def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + name: str, + *, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """It's possible to only provide the name, either by keyword or as the only positional argument.""" @overload - def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + *, + uri: str, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """It's possible to only provide the URI as a keyword argument.""" def __init__( @@ -304,6 +329,7 @@ def __init__( *, group: str = "", extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, ) -> None: if name is None and uri is None: raise TypeError("Asset() requires either 'name' or 'uri'") @@ -316,10 +342,14 @@ def __init__( self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri)) self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type self.extra = _set_extra_default(extra) + self.watchers = watchers or [] def __fspath__(self) -> str: return self.uri + def __hash__(self) -> int: + return hash(self.uri) + @property def normalized_uri(self) -> str | None: """ diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index f608900ee76e1..2da995865ec4f 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -36,6 +36,7 @@ from airflow.assets import Asset, AssetAlias from airflow.assets.manager import asset_manager +from airflow.models import Trigger from airflow.models.asset import ( AssetAliasModel, AssetModel, @@ -55,6 +56,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import Select + from airflow.triggers.base import BaseTrigger from airflow.typing_compat import Self log = logging.getLogger(__name__) @@ -425,3 +427,46 @@ def add_task_asset_references( for task_id, asset_id in referenced_outlets if (task_id, asset_id) not in orm_refs ) + + def add_asset_trigger_references( + self, assets: dict[tuple[str, str], AssetModel], *, session: Session + ) -> None: + for name_uri, asset in self.assets.items(): + asset_model = assets[name_uri] + trigger_class_path_to_asset_dict: dict[str, BaseTrigger] = { + trigger.serialize()[0]: trigger for trigger in asset.watchers + } + + trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_asset_dict.keys()) + trigger_class_paths_from_asset_model: set[str] = { + trigger.classpath for trigger in asset_model.triggers + } + + # Optimization: no diff between the DB and DAG definitions, no update needed + if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model: + continue + + refs_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model + refs_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset + + session.scalars( + select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets)) + ) + + # Remove old references + asset_model.triggers = [ + trigger for trigger in asset_model.triggers if trigger.classpath not in refs_to_remove + ] + + # Add new references + for trigger_class_path in refs_to_add: + trigger_model = session.scalar( + select(Trigger).where(Trigger.classpath == trigger_class_path).limit(1) + ) + + # Create the trigger in the DB if it does not exist + if not trigger_model: + trigger_model = Trigger.from_object(trigger_class_path_to_asset_dict[trigger_class_path]) + session.add(trigger_model) + + asset_model.triggers.append(trigger_model) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e6a67c6ad7e5e..56230cf984680 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1844,6 +1844,7 @@ def bulk_write_to_db( asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_asset_trigger_references(orm_assets, session=session) session.flush() @provide_session diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 9a124d237ed57..d3b618a506e88 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -557,9 +557,9 @@ def __lt__(self, other): def __hash__(self): hash_components: list[Any] = [type(self)] for c in _DAG_HASH_ATTRS: - # task_ids returns a list and lists can't be hashed - if c == "task_ids": - val = tuple(self.task_dict) + # If it is a list, convert to tuple because lists can't be hashed + if isinstance(getattr(self, c, None), list): + val = tuple(getattr(self, c)) else: val = getattr(self, c, None) try: