Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-82 Save references between assets and triggers #43826

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]


Expand Down Expand Up @@ -276,20 +279,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if watcher is a good name for this. What do we expect this to do? If I understand AIP-82 correctly, an external event would fire the trigger, and the trigger would create events for assets associated to it.

Assuming my understanding is correct, the triggers here are not watchers of the asset; rather, the asset watches the triggers. The relationship is the other way around. So it is probably better to call this watch instead? Or maybe this attribute should live on the trigger instead, something like

asset = Asset("example_asset_watchers")

trigger = SqsSensorTrigger(sqs_queue="my_queue", trigger=[asset])

DAG(..., schedule=[asset])

Tell me what you think on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming ... so hard haha. I see your point.

The reason why I called it watchers is because the triggers will watch some external resource and send event on updates. In that sense, to me, the triggers are watchers. I am not strongly again watch if you think it makes more sense. To be very honest, between watchers and watch I dont mind, I think the both of them makes sense.

However, I definitely want the attribute on the asset class, I think it makes more sense and a more deliberate choice for the user to say, I have this asset and I want this asset to be updated when these triggers fire.

) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -299,6 +325,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -311,6 +338,7 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri
Expand Down
74 changes: 74 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.trigger import Trigger
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
Expand All @@ -55,6 +56,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -425,3 +427,75 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
self, dags: dict[str, DagModel], assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}
for name_uri, asset in self.assets.items():
asset_model = assets[name_uri]
trigger_class_path_to_trigger_dict: dict[str, BaseTrigger] = {
trigger.serialize()[0]: trigger for trigger in asset.watchers
}
triggers.update(trigger_class_path_to_trigger_dict)

trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_trigger_dict.keys())
trigger_class_paths_from_asset_model: set[str] = {
trigger.classpath for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model:
continue

diff_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model
diff_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove

if refs_to_add:
all_classpaths = {classpath for classpaths in refs_to_add.values() for classpath in classpaths}
orm_triggers: dict[str, Trigger] = {
trigger.classpath: trigger
for trigger in session.scalars(select(Trigger).where(Trigger.classpath.in_(all_classpaths)))
}

# Create new triggers
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[classpath])
for classpath in all_classpaths
if classpath not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update((trigger.classpath, trigger) for trigger in new_trigger_models)

# Add new references
for name_uri, classpaths in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
[orm_triggers.get(trigger_class_path) for trigger_class_path in classpaths]
)

if refs_to_remove:
# Remove old references
for name_uri, classpaths in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger for trigger in asset_model.triggers if trigger.classpath not in classpaths
]

# Remove references from assets no longer used
orphan_assets = session.scalars(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uranusjr do we need to check AssetActive here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An asset without an AssetActive entry is not referenced anywhere, and trigerring an event to such an asset will therefore simply do nothing. So not checking AssetActive here is not useful in practice, but maybe theoratically a possibility? It depends on what we want the user to be able to do, I guess. @vincbeck Do you think a user should be able to trigger an event on an asset that does not actually exist in any DAGs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I did not know that notion of AssetActive, maybe I could use it.

Do you think a user should be able to trigger an event on an asset that does not actually exist in any DAGs?

Absolutely not, that's what I am doing (or trying to do) here but even further. orphan_assets contains all the assets not used by any DAGs as schedule. In other words, no DAG use an asset in orphan_assets as schedule condition. I am removing all references from these assets since they are not used to schedule DAG. The way I understand it is, all assets with an AssetActive entry right is a subset of orphan_assets ?

select(AssetModel).filter(~AssetModel.consuming_dags.any()).filter(AssetModel.triggers.any())
)
for asset_model in orphan_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
3 changes: 3 additions & 0 deletions airflow/decorators/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath
from airflow.triggers.base import BaseTrigger


class _AssetMainOperator(PythonOperator):
Expand Down Expand Up @@ -116,6 +117,7 @@ class asset:
uri: str | ObjectStoragePath | None = None
group: str = ""
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
Expand All @@ -126,6 +128,7 @@ def __call__(self, f: Callable) -> AssetDefinition:
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
watchers=self.watchers,
function=f,
schedule=self.schedule,
)
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_dags, orm_assets, session=session)
session.flush()

@provide_session
Expand Down
3 changes: 1 addition & 2 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.triggers.base import BaseTrigger


Expand Down Expand Up @@ -141,7 +140,7 @@ def rotate_fernet_key(self):
@classmethod
@internal_api_call
@provide_session
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
classpath, kwargs = trigger.serialize()
return cls(classpath=classpath, kwargs=kwargs)
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down