From 38d9dba8a7110759e96543d79eb6a801ad38f436 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Mon, 20 Jan 2025 16:14:23 +0530 Subject: [PATCH] Use Protocol for `OutletEventAccessor` (#45762) Follow-up to https://github.com/apache/airflow/pull/45727: use a Protocol to allow IDE auto-completion without introducing a runtime dependency --- airflow/models/taskinstance.py | 6 ++--- airflow/serialization/serialized_objects.py | 4 +-- airflow/utils/context.py | 3 ++- airflow/utils/operator_helpers.py | 4 +-- .../providers/edge/example_dags/win_test.py | 2 +- .../amazon/aws/transfers/google_api_to_s3.py | 2 +- .../airflow/sdk/definitions/asset/__init__.py | 9 +++++++ .../src/airflow/sdk/definitions/context.py | 9 ++++--- .../src/airflow/sdk/execution_time/context.py | 10 +------ .../airflow/sdk/execution_time/task_runner.py | 1 - .../{definitions/protocols.py => types.py} | 27 +++++++++++++++++++ task_sdk/tests/execution_time/test_context.py | 9 +++++-- .../serialization/test_serialized_objects.py | 5 ++-- 13 files changed, 63 insertions(+), 28 deletions(-) rename task_sdk/src/airflow/sdk/{definitions/protocols.py => types.py} (69%) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3fa4e8fb4a9a..5e0f4001d2d03c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -163,7 +163,7 @@ from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol + from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol from airflow.timetables.base import DataInterval from airflow.typing_compat import Literal, TypeGuard from airflow.utils.task_group import TaskGroup @@ -2730,7 +2730,7 @@ def _run_raw_task( ) def _register_asset_changes( - self, *, events: OutletEventAccessors, session: Session | None = None + self, *, events: OutletEventAccessorsProtocol, 
session: Session | None = None ) -> None: if session: TaskInstance._register_asset_changes_int(ti=self, events=events, session=session) @@ -2740,7 +2740,7 @@ def _register_asset_changes( @staticmethod @provide_session def _register_asset_changes_int( - ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION + ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION ) -> None: if TYPE_CHECKING: assert ti.task diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d828a9a5b6b241..88e0f200bb24d8 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -56,6 +56,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetAll, AssetAny, @@ -64,7 +65,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -80,7 +81,6 @@ from airflow.utils.context import ( ConnectionAccessor, Context, - OutletEventAccessors, VariableAccessor, ) from airflow.utils.db import LazySelectSequence diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 168243290fabc4..a36202f0793ec0 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -63,6 +63,7 @@ from sqlalchemy.sql.expression import Select, TextClause from airflow.models.baseoperator import BaseOperator + from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: # * Context in 
task_sdk/src/airflow/sdk/definitions/context.py @@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context: return cast(Context, new) -def context_get_outlet_events(context: Context) -> OutletEventAccessors: +def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: try: return context["outlet_events"] except KeyError: diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index cb822aa1cc77b8..ab3c8c89e5e0f2 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -29,7 +29,7 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from airflow.utils.context import OutletEventAccessors + from airflow.sdk.types import OutletEventAccessorsProtocol P = ParamSpec("P") R = TypeVar("R") @@ -230,7 +230,7 @@ def run(*args, **kwargs): ... def ExecutionCallableRunner( func: Callable[P, R], - outlet_events: OutletEventAccessors, + outlet_events: OutletEventAccessorsProtocol, *, logger: logging.Logger, ) -> _ExecutionCallableRunner: diff --git a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py index a2727363d6401c..15735b85d18323 100644 --- a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py +++ b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: try: - from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance + from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance except ImportError: from airflow.models import TaskInstance # type: ignore[assignment] from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py index a3d6bd619ce4b9..157477341b44d3 100644 --- 
a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py +++ b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: try: - from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol + from airflow.sdk.types import RuntimeTaskInstanceProtocol except ImportError: from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol # type: ignore[assignment] from airflow.utils.context import Context diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index ea89f1b681701a..51d4abbeda45df 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -660,3 +660,12 @@ def as_expression(self) -> Any: :meta private: """ return {"all": [o.as_expression() for o in self.objects]} + + +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py index 46a92ec2beb149..b98c1a2e0489cc 100644 --- a/task_sdk/src/airflow/sdk/definitions/context.py +++ b/task_sdk/src/airflow/sdk/definitions/context.py @@ -27,7 +27,11 @@ from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol + from airflow.sdk.types import ( + DagRunProtocol, + OutletEventAccessorsProtocol, + RuntimeTaskInstanceProtocol, + ) class Context(TypedDict, total=False): @@ -38,8 +42,7 @@ class Context(TypedDict, total=False): dag_run: DagRunProtocol data_interval_end: datetime | None data_interval_start: datetime | None - # outlet_events: OutletEventAccessors - 
outlet_events: Any + outlet_events: OutletEventAccessorsProtocol ds: str ds_nodash: str expanded_ti_count: int | None diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 918526c3004c2e..a068b53aec723d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -28,6 +28,7 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasEvent, AssetAliasUniqueKey, AssetNameRef, AssetRef, @@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool: return True -@attrs.define -class AssetAliasEvent: - """Representation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] - - @attrs.define class OutletEventAccessor: """Wrapper to access an outlet asset event in template.""" diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index d252c24be180c0..d4816c8ae59f6a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -137,7 +137,6 @@ def get_template_context(self) -> Context: } context.update(context_from_server) - # TODO: We should use/move TypeDict from airflow.utils.context.Context return context def render_templates( diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py b/task_sdk/src/airflow/sdk/types.py similarity index 69% rename from task_sdk/src/airflow/sdk/definitions/protocols.py rename to task_sdk/src/airflow/sdk/types.py index 80dba602ff135a..35ee9f8e38cab5 100644 --- a/task_sdk/src/airflow/sdk/definitions/protocols.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: + from collections.abc import Iterator from datetime import datetime + from airflow.sdk.definitions.asset import Asset, AssetAlias, 
AssetAliasEvent, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator @@ -65,3 +67,28 @@ def xcom_pull( ) -> Any: ... def xcom_push(self, key: str, value: Any) -> None: ... + + +class OutletEventAccessorProtocol(Protocol): + """Protocol for managing access to a specific outlet event accessor.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] + asset_alias_events: list[AssetAliasEvent] + + def __init__( + self, + *, + key: BaseAssetUniqueKey, + extra: dict[str, Any], + asset_alias_events: list[AssetAliasEvent], + ) -> None: ... + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... + + +class OutletEventAccessorsProtocol(Protocol): + """Protocol for managing access to outlet event accessors.""" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... + def __len__(self) -> int: ... + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ... diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index e3ef15dc934cf6..a155f65a9f57b8 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,13 +22,18 @@ import pytest from airflow.sdk import get_current_context -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAliasUniqueKey, + AssetUniqueKey, +) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( - AssetAliasEvent, ConnectionAccessor, OutletEventAccessor, OutletEventAccessors, diff --git a/tests/serialization/test_serialized_objects.py 
b/tests/serialization/test_serialized_objects.py index 707595b92ffa22..06bb477becdf4e 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -42,13 +42,12 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey -from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey +from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State