Skip to content

Commit

Permalink
AIP-82 Save references between assets and triggers
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Nov 8, 2024
1 parent 24b2369 commit 598bc69
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
36 changes: 33 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


from airflow.configuration import conf

Expand Down Expand Up @@ -281,20 +283,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger] = []

asset_type: ClassVar[str] = ""
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -304,6 +329,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -316,10 +342,14 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri

def __hash__(self) -> int:
return hash(self.uri)

@property
def normalized_uri(self) -> str | None:
"""
Expand Down
45 changes: 45 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from airflow.assets import Asset, AssetAlias
from airflow.assets.manager import asset_manager
from airflow.models import Trigger
from airflow.models.asset import (
AssetAliasModel,
AssetModel,
Expand All @@ -55,6 +56,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -425,3 +427,46 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
    self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
    """
    Sync the triggers attached to each asset model with the asset's watchers.

    For every entry of ``self.assets``, the classpaths of the asset's
    ``watchers`` are compared with the classpaths of the triggers currently
    attached to the matching model in ``assets``. Triggers whose classpath is
    no longer declared are detached; newly declared classpaths are attached,
    reusing an existing ``Trigger`` row with that classpath or creating one.

    :param assets: Asset models, looked up with the same keys as
        ``self.assets``. A key missing from this mapping raises ``KeyError``.
    :param session: Session used to look up and add ``Trigger`` rows. It is
        only queried when a new reference has to be added.
    """
    for name_uri, asset in self.assets.items():
        asset_model = assets[name_uri]

        # NOTE: keyed by classpath only, so watchers of one asset that share
        # a classpath collapse into a single entry (the last one wins).
        watchers_by_classpath: dict[str, BaseTrigger] = {
            trigger.serialize()[0]: trigger for trigger in asset.watchers
        }

        declared_classpaths: set[str] = set(watchers_by_classpath)
        stored_classpaths: set[str] = {trigger.classpath for trigger in asset_model.triggers}

        # Optimization: no diff between the DB and DAG definitions, no update needed
        if declared_classpaths == stored_classpaths:
            continue

        refs_to_add = declared_classpaths - stored_classpaths
        refs_to_remove = stored_classpaths - declared_classpaths

        # Remove old references
        asset_model.triggers = [
            trigger for trigger in asset_model.triggers if trigger.classpath not in refs_to_remove
        ]

        # Add new references
        for trigger_class_path in refs_to_add:
            trigger_model = session.scalar(
                select(Trigger).where(Trigger.classpath == trigger_class_path).limit(1)
            )

            # Create the trigger in the DB if it does not exist
            if not trigger_model:
                trigger_model = Trigger.from_object(watchers_by_classpath[trigger_class_path])
                session.add(trigger_model)

            asset_model.triggers.append(trigger_model)
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_assets, session=session)
session.flush()

@provide_session
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down

0 comments on commit 598bc69

Please sign in to comment.