diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3fa4e8fb4a9..01d862cfd9f66 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -163,7 +163,7 @@ from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol from airflow.timetables.base import DataInterval from airflow.typing_compat import Literal, TypeGuard from airflow.utils.task_group import TaskGroup @@ -2730,7 +2730,7 @@ def _run_raw_task( ) def _register_asset_changes( - self, *, events: OutletEventAccessors, session: Session | None = None + self, *, events: OutletEventAccessorsProtocol, session: Session | None = None ) -> None: if session: TaskInstance._register_asset_changes_int(ti=self, events=events, session=session) @@ -2740,7 +2740,7 @@ def _register_asset_changes( @staticmethod @provide_session def _register_asset_changes_int( - ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION + ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION ) -> None: if TYPE_CHECKING: assert ti.task diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d828a9a5b6b24..88e0f200bb24d 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -56,6 +56,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetAll, AssetAny, @@ -64,7 +65,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -80,7 +81,6 @@ from airflow.utils.context import ( ConnectionAccessor, Context, - OutletEventAccessors, VariableAccessor, ) from airflow.utils.db import LazySelectSequence diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 168243290fabc..5d67f4a62f9bf 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -63,6 +63,7 @@ from sqlalchemy.sql.expression import Select, TextClause from airflow.models.baseoperator import BaseOperator + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: # * Context in task_sdk/src/airflow/sdk/definitions/context.py @@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context: return cast(Context, new) -def context_get_outlet_events(context: Context) -> OutletEventAccessors: +def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: try: return context["outlet_events"] except KeyError: diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index cb822aa1cc77b..5a5cef2eac0a6 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -29,7 +29,7 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from airflow.utils.context import OutletEventAccessors + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol P = ParamSpec("P") R = TypeVar("R") @@ -230,7 +230,7 @@ def run(*args, **kwargs): ... def ExecutionCallableRunner( func: Callable[P, R], - outlet_events: OutletEventAccessors, + outlet_events: OutletEventAccessorsProtocol, *, logger: logging.Logger, ) -> _ExecutionCallableRunner: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index ea89f1b681701..51d4abbeda45d 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -660,3 +660,12 @@ def as_expression(self) -> Any: :meta private: """ return {"all": [o.as_expression() for o in self.objects]} + + +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py index 46a92ec2beb14..cb5213da2e2ed 100644 --- a/task_sdk/src/airflow/sdk/definitions/context.py +++ b/task_sdk/src/airflow/sdk/definitions/context.py @@ -27,7 +27,11 @@ from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol + from airflow.sdk.definitions.protocols import ( + DagRunProtocol, + OutletEventAccessorsProtocol, + RuntimeTaskInstanceProtocol, + ) class Context(TypedDict, total=False): @@ -38,8 +42,7 @@ class Context(TypedDict, total=False): dag_run: DagRunProtocol data_interval_end: datetime | None data_interval_start: datetime | None - # outlet_events: OutletEventAccessors - outlet_events: Any + outlet_events: OutletEventAccessorsProtocol ds: str ds_nodash: str expanded_ti_count: int | None diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py b/task_sdk/src/airflow/sdk/definitions/protocols.py index 80dba602ff135..35ee9f8e38cab 100644 --- a/task_sdk/src/airflow/sdk/definitions/protocols.py +++ b/task_sdk/src/airflow/sdk/definitions/protocols.py @@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: + from collections.abc import Iterator from datetime import datetime + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator @@ -65,3 +67,28 @@ def xcom_pull( ) -> Any: ... def xcom_push(self, key: str, value: Any) -> None: ... + + +class OutletEventAccessorProtocol(Protocol): + """Protocol for managing access to a specific outlet event accessor.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] + asset_alias_events: list[AssetAliasEvent] + + def __init__( + self, + *, + key: BaseAssetUniqueKey, + extra: dict[str, Any], + asset_alias_events: list[AssetAliasEvent], + ) -> None: ... + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... + + +class OutletEventAccessorsProtocol(Protocol): + """Protocol for managing access to outlet event accessors.""" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... + def __len__(self) -> int: ... + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ... diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 918526c3004c2..a068b53aec723 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -28,6 +28,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetNameRef, AssetRef, @@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool: return True -@attrs.define -class AssetAliasEvent: - """Representation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] - - @attrs.define class OutletEventAccessor: """Wrapper to access an outlet asset event in template.""" diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index e3ef15dc934cf..a155f65a9f57b 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,13 +22,18 @@ import pytest from airflow.sdk import get_current_context -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAliasUniqueKey, + AssetUniqueKey, +) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( - AssetAliasEvent, ConnectionAccessor, OutletEventAccessor, OutletEventAccessors, diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 707595b92ffa2..06bb477becdf4 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -42,13 +42,12 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey +from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State