Skip to content

Commit

Permalink
AIP-82 Save references between assets and triggers
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Nov 14, 2024
1 parent 87445bf commit ae6cfc5
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 8 deletions.
34 changes: 31 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]


Expand Down Expand Up @@ -276,20 +279,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -299,6 +325,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -311,6 +338,7 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri
Expand Down
72 changes: 72 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.trigger import Trigger
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
Expand All @@ -55,6 +56,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -425,3 +427,73 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}
for name_uri, asset in self.assets.items():
asset_model = assets[name_uri]
trigger_class_path_to_trigger_dict: dict[str, BaseTrigger] = {
trigger.serialize()[0]: trigger for trigger in asset.watchers
}
triggers.update(trigger_class_path_to_trigger_dict)

trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_trigger_dict.keys())
trigger_class_paths_from_asset_model: set[str] = {
trigger.classpath for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model:
continue

diff_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model
diff_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove

if refs_to_add:
all_classpaths = {classpath for classpaths in refs_to_add.values() for classpath in classpaths}
orm_triggers: dict[str, Trigger] = {
trigger.classpath: trigger
for trigger in session.scalars(select(Trigger).where(Trigger.classpath.in_(all_classpaths)))
}

# Create new triggers
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[classpath])
for classpath in all_classpaths
if classpath not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update((trigger.classpath, trigger) for trigger in new_trigger_models)

# Add new references
for name_uri, classpaths in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
[orm_triggers.get(trigger_class_path) for trigger_class_path in classpaths]
)

if refs_to_remove:
# Remove old references
for name_uri, classpaths in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger for trigger in asset_model.triggers if trigger.classpath not in classpaths
]

# Remove references from assets no longer used
all_assets = session.scalars(select(AssetModel))
for asset_model in all_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_assets, session=session)
session.flush()

@provide_session
Expand Down
3 changes: 1 addition & 2 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.triggers.base import BaseTrigger


Expand Down Expand Up @@ -141,7 +140,7 @@ def rotate_fernet_key(self):
@classmethod
@internal_api_call
@provide_session
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
classpath, kwargs = trigger.serialize()
return cls(classpath=classpath, kwargs=kwargs)
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down

0 comments on commit ae6cfc5

Please sign in to comment.