From 02e381915e710866019a97a7715b101d8742989c Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 18 Jan 2025 03:02:11 +0530 Subject: [PATCH] Use Protocol for `OutletEventAccessor` Follow-up of https://github.com/apache/airflow/pull/45727 to use Protocol to allow auto-completion on IDE while not introducing runtime dep --- airflow/models/taskinstance.py | 6 ++--- airflow/serialization/serialized_objects.py | 3 ++- airflow/utils/context.py | 3 ++- airflow/utils/operator_helpers.py | 4 +-- .../airflow/sdk/definitions/asset/__init__.py | 9 +++++++ .../src/airflow/sdk/definitions/context.py | 9 ++++--- .../src/airflow/sdk/definitions/protocols.py | 27 +++++++++++++++++++ .../src/airflow/sdk/execution_time/context.py | 10 +------ task_sdk/tests/execution_time/test_context.py | 9 +++++-- .../serialization/test_serialized_objects.py | 5 ++-- 10 files changed, 61 insertions(+), 24 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3fa4e8fb4a9a..01d862cfd9f66c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -163,7 +163,7 @@ from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol from airflow.timetables.base import DataInterval from airflow.typing_compat import Literal, TypeGuard from airflow.utils.task_group import TaskGroup @@ -2730,7 +2730,7 @@ def _run_raw_task( ) def _register_asset_changes( - self, *, events: OutletEventAccessors, session: Session | None = None + self, *, events: OutletEventAccessorsProtocol, session: Session | None = None ) -> None: if session: TaskInstance._register_asset_changes_int(ti=self, events=events, session=session) @@ -2740,7 +2740,7 @@ def _register_asset_changes( @staticmethod @provide_session def _register_asset_changes_int( - ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION + ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION ) -> None: if TYPE_CHECKING: assert ti.task diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d828a9a5b6b241..a8f731f2748344 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -56,6 +56,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetAll, AssetAny, @@ -64,7 +65,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.execution_time.context import OutletEventAccessor from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 168243290fabc4..5d67f4a62f9bfa 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -63,6 +63,7 @@ from sqlalchemy.sql.expression import Select, TextClause from airflow.models.baseoperator import BaseOperator + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: # * Context in task_sdk/src/airflow/sdk/definitions/context.py @@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context: return cast(Context, new) -def context_get_outlet_events(context: Context) -> OutletEventAccessors: +def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: try: return context["outlet_events"] except KeyError: diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index cb822aa1cc77b8..5a5cef2eac0a65 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -29,7 +29,7 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from airflow.utils.context import OutletEventAccessors + from airflow.sdk.definitions.protocols import OutletEventAccessorsProtocol P = ParamSpec("P") R = TypeVar("R") @@ -230,7 +230,7 @@ def run(*args, **kwargs): ... def ExecutionCallableRunner( func: Callable[P, R], - outlet_events: OutletEventAccessors, + outlet_events: OutletEventAccessorsProtocol, *, logger: logging.Logger, ) -> _ExecutionCallableRunner: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index ea89f1b681701a..51d4abbeda45df 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -660,3 +660,12 @@ def as_expression(self) -> Any: :meta private: """ return {"all": [o.as_expression() for o in self.objects]} + + +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py index 46a92ec2beb149..cb5213da2e2ed5 100644 --- a/task_sdk/src/airflow/sdk/definitions/context.py +++ b/task_sdk/src/airflow/sdk/definitions/context.py @@ -27,7 +27,11 @@ from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol + from airflow.sdk.definitions.protocols import ( + DagRunProtocol, + OutletEventAccessorsProtocol, + RuntimeTaskInstanceProtocol, + ) class Context(TypedDict, total=False): @@ -38,8 +42,7 @@ class Context(TypedDict, total=False): dag_run: DagRunProtocol data_interval_end: datetime | None data_interval_start: datetime | None - # outlet_events: OutletEventAccessors - outlet_events: Any + outlet_events: OutletEventAccessorsProtocol ds: str ds_nodash: str expanded_ti_count: int | None diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py b/task_sdk/src/airflow/sdk/definitions/protocols.py index 80dba602ff135a..35ee9f8e38cab5 100644 --- a/task_sdk/src/airflow/sdk/definitions/protocols.py +++ b/task_sdk/src/airflow/sdk/definitions/protocols.py @@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: + from collections.abc import Iterator from datetime import datetime + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator @@ -65,3 +67,28 @@ def xcom_pull( ) -> Any: ... def xcom_push(self, key: str, value: Any) -> None: ... + + +class OutletEventAccessorProtocol(Protocol): + """Protocol for managing access to a specific outlet event accessor.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] + asset_alias_events: list[AssetAliasEvent] + + def __init__( + self, + *, + key: BaseAssetUniqueKey, + extra: dict[str, Any], + asset_alias_events: list[AssetAliasEvent], + ) -> None: ... + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... + + +class OutletEventAccessorsProtocol(Protocol): + """Protocol for managing access to outlet event accessors.""" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... + def __len__(self) -> int: ... + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ... diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 918526c3004c2e..a068b53aec723d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -28,6 +28,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetNameRef, AssetRef, @@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool: return True -@attrs.define -class AssetAliasEvent: - """Representation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] - - @attrs.define class OutletEventAccessor: """Wrapper to access an outlet asset event in template.""" diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index e3ef15dc934cf6..a155f65a9f57b8 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,13 +22,18 @@ import pytest from airflow.sdk import get_current_context -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAliasUniqueKey, + AssetUniqueKey, +) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( - AssetAliasEvent, ConnectionAccessor, OutletEventAccessor, OutletEventAccessors, diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 707595b92ffa22..06bb477becdf4e 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -42,13 +42,12 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey +from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State