From c2660ea6f284188f2b7fc2d5ac8d8193cb51f9f9 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Fri, 8 Nov 2024 09:49:01 -0500 Subject: [PATCH] AIP-82 Save references between assets and triggers --- airflow/assets/__init__.py | 34 +++++++++- airflow/dag_processing/collection.py | 72 +++++++++++++++++++++ airflow/models/dag.py | 1 + airflow/models/trigger.py | 3 +- task_sdk/src/airflow/sdk/definitions/dag.py | 6 +- 5 files changed, 108 insertions(+), 8 deletions(-) diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index f1d36ac12b73..c04abfbb0789 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -36,6 +36,9 @@ from sqlalchemy.orm.session import Session + from airflow.triggers.base import BaseTrigger + + __all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"] @@ -276,20 +279,43 @@ class Asset(os.PathLike, BaseAsset): uri: str group: str extra: dict[str, Any] + watchers: list[BaseTrigger] asset_type: ClassVar[str] = "asset" __version__: ClassVar[int] = 1 @overload - def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + name: str, + uri: str, + *, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """Canonical; both name and uri are provided.""" @overload - def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + name: str, + *, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """It's possible to only provide the name, either by keyword or as the only positional argument.""" @overload - def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None: + def __init__( + self, + *, + uri: str, + group: str = "", + extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, + ) -> None: """It's possible to only provide the URI as a keyword argument.""" def __init__( @@ -299,6 +325,7 @@ def __init__( *, group: str = "", extra: dict | None = None, + watchers: list[BaseTrigger] | None = None, ) -> None: if name is None and uri is None: raise TypeError("Asset() requires either 'name' or 'uri'") @@ -311,6 +338,7 @@ def __init__( self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri)) self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type self.extra = _set_extra_default(extra) + self.watchers = watchers or [] def __fspath__(self) -> str: return self.uri diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index f608900ee76e..7a496756b605 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -36,6 +36,7 @@ from airflow.assets import Asset, AssetAlias from airflow.assets.manager import asset_manager +from airflow.models import Trigger from airflow.models.asset import ( AssetAliasModel, AssetModel, @@ -55,6 +56,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import Select + from airflow.triggers.base import BaseTrigger from airflow.typing_compat import Self log = logging.getLogger(__name__) @@ -425,3 +427,73 @@ def add_task_asset_references( for task_id, asset_id in referenced_outlets if (task_id, asset_id) not in orm_refs ) + + def add_asset_trigger_references( + self, assets: dict[tuple[str, str], AssetModel], *, session: Session + ) -> None: + # Update references from assets being used + refs_to_add: dict[tuple[str, str], set[str]] = {} + refs_to_remove: dict[tuple[str, str], set[str]] = {} + triggers: dict[str, BaseTrigger] = {} + for name_uri, asset in self.assets.items(): + asset_model = assets[name_uri] + trigger_class_path_to_trigger_dict: dict[str, BaseTrigger] = { + trigger.serialize()[0]: trigger for trigger in asset.watchers + } + triggers.update(trigger_class_path_to_trigger_dict) + + trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_trigger_dict.keys()) + trigger_class_paths_from_asset_model: set[str] = { + trigger.classpath for trigger in asset_model.triggers + } + + # Optimization: no diff between the DB and DAG definitions, no update needed + if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model: + continue + + diff_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model + diff_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset + if diff_to_add: + refs_to_add[name_uri] = diff_to_add + if diff_to_remove: + refs_to_remove[name_uri] = diff_to_remove + + if refs_to_add: + all_classpaths = {classpath for classpaths in refs_to_add.values() for classpath in classpaths} + orm_triggers: dict[str, Trigger] = { + trigger.classpath: trigger + for trigger in session.scalars(select(Trigger).where(Trigger.classpath.in_(all_classpaths))) + } + + # Create new triggers + new_trigger_models = [ + trigger + for trigger in [ + Trigger.from_object(triggers[classpath]) + for classpath in all_classpaths + if classpath not in orm_triggers + ] + ] + session.add_all(new_trigger_models) + orm_triggers.update((trigger.classpath, trigger) for trigger in new_trigger_models) + + # Add new references + for name_uri, classpaths in refs_to_add.items(): + asset_model = assets[name_uri] + asset_model.triggers.extend( + [orm_triggers.get(trigger_class_path) for trigger_class_path in classpaths] + ) + + if refs_to_remove: + # Remove old references + for name_uri, classpaths in refs_to_remove.items(): + asset_model = assets[name_uri] + asset_model.triggers = [ + trigger for trigger in asset_model.triggers if trigger.classpath not in classpaths + ] + + # Remove references from assets no longer used + all_assets = session.scalars(select(AssetModel)) + for asset_model in all_assets: + if (asset_model.name, asset_model.uri) not in self.assets: + asset_model.triggers = [] diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e48ec0a9a9c5..8a7a2c3ad44c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1842,6 +1842,7 @@ def bulk_write_to_db( asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_asset_trigger_references(orm_assets, session=session) session.flush() @provide_session diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index 27868daf083f..dd90b0ac67fb 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -38,7 +38,6 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import Select - from airflow.serialization.pydantic.trigger import TriggerPydantic from airflow.triggers.base import BaseTrigger @@ -141,7 +140,7 @@ def rotate_fernet_key(self): @classmethod @internal_api_call @provide_session - def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic: + def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger: """Alternative constructor that creates a trigger row based directly off of a Trigger object.""" classpath, kwargs = trigger.serialize() return cls(classpath=classpath, kwargs=kwargs) diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 2573c62eb312..a28209f181ca 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -557,9 +557,9 @@ def __lt__(self, other): def __hash__(self): hash_components: list[Any] = [type(self)] for c in _DAG_HASH_ATTRS: - # task_ids returns a list and lists can't be hashed - if c == "task_ids": - val = tuple(self.task_dict) + # If it is a list, convert to tuple because lists can't be hashed + if isinstance(getattr(self, c, None), list): + val = tuple(getattr(self, c)) else: val = getattr(self, c, None) try: