From c4c5c3e76dfa48c921f73c61506383e960a15f04 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Fri, 8 Nov 2024 09:49:01 -0500
Subject: [PATCH] AIP-82 Save references between assets and triggers

---
 airflow/assets/__init__.py                  | 34 +++++++++-
 airflow/dag_processing/collection.py        | 74 +++++++++++++++++++++
 airflow/decorators/assets.py                |  3 +
 airflow/models/dag.py                       |  1 +
 airflow/models/trigger.py                   |  3 +-
 task_sdk/src/airflow/sdk/definitions/dag.py |  6 +-
 6 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index f1d36ac12b73..c04abfbb0789 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -36,6 +36,9 @@
     from sqlalchemy.orm.session import Session
 
+    from airflow.triggers.base import BaseTrigger
+
+
 __all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]
@@ -276,20 +279,43 @@ class Asset(os.PathLike, BaseAsset):
     uri: str
     group: str
     extra: dict[str, Any]
+    watchers: list[BaseTrigger]
 
     asset_type: ClassVar[str] = "asset"
     __version__: ClassVar[int] = 1
 
     @overload
-    def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
+    def __init__(
+        self,
+        name: str,
+        uri: str,
+        *,
+        group: str = "",
+        extra: dict | None = None,
+        watchers: list[BaseTrigger] | None = None,
+    ) -> None:
         """Canonical; both name and uri are provided."""
 
     @overload
-    def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
+    def __init__(
+        self,
+        name: str,
+        *,
+        group: str = "",
+        extra: dict | None = None,
+        watchers: list[BaseTrigger] | None = None,
+    ) -> None:
         """It's possible to only provide the name, either by keyword or as the only positional argument."""
 
     @overload
-    def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
+    def __init__(
+        self,
+        *,
+        uri: str,
+        group: str = "",
+        extra: dict | None = None,
+        watchers: list[BaseTrigger] | None = None,
+    ) -> None:
         """It's possible to only provide the URI as a keyword argument."""
 
     def __init__(
@@ -299,6 +325,7 @@ def __init__(
         *,
         group: str = "",
         extra: dict | None = None,
+        watchers: list[BaseTrigger] | None = None,
     ) -> None:
         if name is None and uri is None:
             raise TypeError("Asset() requires either 'name' or 'uri'")
@@ -311,6 +338,7 @@ def __init__(
         self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
         self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
         self.extra = _set_extra_default(extra)
+        self.watchers = watchers or []
 
     def __fspath__(self) -> str:
         return self.uri
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index f608900ee76e..657ce52e04c7 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -45,6 +45,7 @@
 )
 from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
 from airflow.models.dagrun import DagRun
+from airflow.models.trigger import Trigger
 from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.timezone import utcnow
 from airflow.utils.types import DagRunType
@@ -55,6 +56,7 @@
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import Select
 
+    from airflow.triggers.base import BaseTrigger
     from airflow.typing_compat import Self
 
 log = logging.getLogger(__name__)
@@ -425,3 +427,75 @@ def add_task_asset_references(
             for task_id, asset_id in referenced_outlets
             if (task_id, asset_id) not in orm_refs
         )
+
+    def add_asset_trigger_references(
+        self, dags: dict[str, DagModel], assets: dict[tuple[str, str], AssetModel], *, session: Session
+    ) -> None:
+        # Update references from assets being used
+        refs_to_add: dict[tuple[str, str], set[str]] = {}
+        refs_to_remove: dict[tuple[str, str], set[str]] = {}
+        triggers: dict[str, BaseTrigger] = {}
+        for name_uri, asset in self.assets.items():
+            asset_model = assets[name_uri]
+            trigger_class_path_to_trigger_dict: dict[str, BaseTrigger] = {
+                trigger.serialize()[0]: trigger for trigger in asset.watchers
+            }
+            triggers.update(trigger_class_path_to_trigger_dict)
+
+            trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_trigger_dict.keys())
+            trigger_class_paths_from_asset_model: set[str] = {
+                trigger.classpath for trigger in asset_model.triggers
+            }
+
+            # Optimization: no diff between the DB and DAG definitions, no update needed
+            if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model:
+                continue
+
+            diff_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model
+            diff_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset
+            if diff_to_add:
+                refs_to_add[name_uri] = diff_to_add
+            if diff_to_remove:
+                refs_to_remove[name_uri] = diff_to_remove
+
+        if refs_to_add:
+            all_classpaths = {classpath for classpaths in refs_to_add.values() for classpath in classpaths}
+            orm_triggers: dict[str, Trigger] = {
+                trigger.classpath: trigger
+                for trigger in session.scalars(select(Trigger).where(Trigger.classpath.in_(all_classpaths)))
+            }
+
+            # Create new triggers
+            new_trigger_models = [
+                trigger
+                for trigger in [
+                    Trigger.from_object(triggers[classpath])
+                    for classpath in all_classpaths
+                    if classpath not in orm_triggers
+                ]
+            ]
+            session.add_all(new_trigger_models)
+            orm_triggers.update((trigger.classpath, trigger) for trigger in new_trigger_models)
+
+            # Add new references
+            for name_uri, classpaths in refs_to_add.items():
+                asset_model = assets[name_uri]
+                asset_model.triggers.extend(
+                    [orm_triggers.get(trigger_class_path) for trigger_class_path in classpaths]
+                )
+
+        if refs_to_remove:
+            # Remove old references
+            for name_uri, classpaths in refs_to_remove.items():
+                asset_model = assets[name_uri]
+                asset_model.triggers = [
+                    trigger for trigger in asset_model.triggers if trigger.classpath not in classpaths
+                ]
+
+        # Remove references from assets no longer used
+        orphan_assets = session.scalars(
+            select(AssetModel).filter(~AssetModel.consuming_dags.any()).filter(AssetModel.triggers.any())
+        )
+        for asset_model in orphan_assets:
+            if (asset_model.name, asset_model.uri) not in self.assets:
+                asset_model.triggers = []
diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py
index 2f5052c2d5c9..0bf0a643a40f 100644
--- a/airflow/decorators/assets.py
+++ b/airflow/decorators/assets.py
@@ -30,6 +30,7 @@
 if TYPE_CHECKING:
     from airflow.io.path import ObjectStoragePath
+    from airflow.triggers.base import BaseTrigger
 
 
 class _AssetMainOperator(PythonOperator):
@@ -116,6 +117,7 @@ class asset:
     uri: str | ObjectStoragePath | None = None
     group: str = ""
     extra: dict[str, Any] = attrs.field(factory=dict)
+    watchers: list[BaseTrigger] = attrs.field(factory=list)
 
     def __call__(self, f: Callable) -> AssetDefinition:
         if (name := f.__name__) != f.__qualname__:
@@ -126,6 +128,7 @@ def __call__(self, f: Callable) -> AssetDefinition:
             uri=name if self.uri is None else str(self.uri),
             group=self.group,
             extra=self.extra,
+            watchers=self.watchers,
             function=f,
             schedule=self.schedule,
         )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e48ec0a9a9c5..c8d576f7afcf 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1842,6 +1842,7 @@ def bulk_write_to_db(
         asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
         asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
         asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
+        asset_op.add_asset_trigger_references(orm_dags, orm_assets, session=session)
         session.flush()
 
     @provide_session
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 27868daf083f..dd90b0ac67fb 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -38,7 +38,6 @@
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import Select
 
-    from airflow.serialization.pydantic.trigger import TriggerPydantic
     from airflow.triggers.base import BaseTrigger
 
 
@@ -141,7 +140,7 @@ def rotate_fernet_key(self):
     @classmethod
     @internal_api_call
     @provide_session
-    def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
+    def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger:
         """Alternative constructor that creates a trigger row based directly off of a Trigger object."""
         classpath, kwargs = trigger.serialize()
         return cls(classpath=classpath, kwargs=kwargs)
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 2573c62eb312..a28209f181ca 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -557,9 +557,9 @@ def __lt__(self, other):
     def __hash__(self):
         hash_components: list[Any] = [type(self)]
         for c in _DAG_HASH_ATTRS:
-            # task_ids returns a list and lists can't be hashed
-            if c == "task_ids":
-                val = tuple(self.task_dict)
+            # If it is a list, convert to tuple because lists can't be hashed
+            if isinstance(getattr(self, c, None), list):
+                val = tuple(getattr(self, c))
             else:
                 val = getattr(self, c, None)
             try:
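
Usage note (editor's sketch, not part of the patch): with this change applied, a DAG file can pass
watchers to an Asset, and the DAG processor persists each watcher as a Trigger row linked to the
corresponding AssetModel via add_asset_trigger_references above. The snippet below only illustrates
the new keyword argument; TimeDeltaTrigger is a stand-in for whatever BaseTrigger subclass an asset
would realistically watch, and the asset name and URI are made up.

    from datetime import timedelta

    from airflow.assets import Asset
    from airflow.triggers.temporal import TimeDeltaTrigger

    # During DAG parsing each watcher is keyed by its serialized classpath, so a trigger
    # with a given classpath is stored once even if several assets reference it.
    example_asset = Asset(
        "example_asset",
        uri="s3://bucket/key",  # hypothetical URI used only for this sketch
        watchers=[TimeDeltaTrigger(timedelta(minutes=5))],
    )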