From 943ecd78a3b6fe059f209cd19cd79a5dad77fa51 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 17 Jan 2025 04:04:10 +0530 Subject: [PATCH] AIP-72: Add support for `outlet_events` in Task Context part of https://github.com/apache/airflow/issues/45717 This PR adds support for `outlet_events` in Context dict within the Task SDK by adding an endpoint on the API Server which is fetched when outlet_events is accessed. --- .../execution_api/datamodels/asset.py | 62 ++++++++ .../execution_api/routes/__init__.py | 10 +- .../execution_api/routes/assets.py | 79 +++++++++++ airflow/serialization/serialized_objects.py | 3 +- airflow/utils/context.py | 105 ++------------ task_sdk/src/airflow/sdk/api/client.py | 26 ++++ .../airflow/sdk/api/datamodels/_generated.py | 48 +++++++ .../src/airflow/sdk/execution_time/comms.py | 34 ++++- .../src/airflow/sdk/execution_time/context.py | 126 +++++++++++++++- .../airflow/sdk/execution_time/supervisor.py | 9 ++ .../airflow/sdk/execution_time/task_runner.py | 4 +- task_sdk/tests/execution_time/test_context.py | 75 ++++++++++ .../tests/execution_time/test_supervisor.py | 54 ++++++- .../tests/execution_time/test_task_runner.py | 9 +- .../execution_api/routes/test_assets.py | 134 ++++++++++++++++++ .../serialization/test_serialized_objects.py | 3 +- tests/utils/test_context.py | 3 +- 17 files changed, 679 insertions(+), 105 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/datamodels/asset.py create mode 100644 airflow/api_fastapi/execution_api/routes/assets.py create mode 100644 tests/api_fastapi/execution_api/routes/test_assets.py diff --git a/airflow/api_fastapi/execution_api/datamodels/asset.py b/airflow/api_fastapi/execution_api/datamodels/asset.py new file mode 100644 index 00000000000000..36b744a54f9ec8 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/asset.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datetime import datetime + +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagScheduleAssetReference(BaseModel): + """DAG schedule reference schema for assets.""" + + dag_id: str + created_at: datetime + updated_at: datetime + + +class TaskOutletAssetReference(BaseModel): + """Task outlet reference schema for assets.""" + + dag_id: str + task_id: str + created_at: datetime + updated_at: datetime + + +class AssetResponse(BaseModel): + """Asset schema for responses.""" + + id: int + name: str + uri: str + group: str + extra: dict | None = None + created_at: datetime + updated_at: datetime + consuming_dags: list[DagScheduleAssetReference] + producing_tasks: list[TaskOutletAssetReference] + aliases: list[AssetAliasResponse] + + +class AssetAliasResponse(BaseModel): + """Asset alias schema for responses.""" + + id: int + name: str + group: str diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow/api_fastapi/execution_api/routes/__init__.py index 0383503f18b874..793cd8fe084944 100644 --- a/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,9 +17,17 @@ from __future__ import annotations from airflow.api_fastapi.common.router import 
AirflowRouter -from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms +from airflow.api_fastapi.execution_api.routes import ( + assets, + connections, + health, + task_instances, + variables, + xcoms, +) execution_api_router = AirflowRouter() +execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) execution_api_router.include_router(health.router, tags=["Health"]) execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) diff --git a/airflow/api_fastapi/execution_api/routes/assets.py b/airflow/api_fastapi/execution_api/routes/assets.py new file mode 100644 index 00000000000000..d2e21bd0c292a4 --- /dev/null +++ b/airflow/api_fastapi/execution_api/routes/assets.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import Annotated + +from fastapi import HTTPException, Query, status +from sqlalchemy import select + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse +from airflow.models.asset import AssetModel + +# TODO: Add dependency on JWT token +router = AirflowRouter( + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) + + +@router.get( + "", + responses={ + status.HTTP_400_BAD_REQUEST: { + "description": "Either 'name' or 'uri' query parameter must be provided" + }, + status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the asset"}, + }, +) +def get_asset( + session: SessionDep, + name: Annotated[str | None, Query(description="The name of the Asset")] = None, + uri: Annotated[str | None, Query(description="The URI of the Asset")] = None, +) -> AssetResponse: + """Get an Airflow Asset by `name` or `uri`.""" + if name: + asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with name {name} not found") + elif uri: + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with URI {uri} not found") + else: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={ + "reason": "bad_request", + "message": "Either 'name' or 'uri' query parameter must be provided", + }, + ) + return AssetResponse.model_validate(asset) + + +def _raise_if_not_found(asset, msg): + if asset is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": msg, + }, + ) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 41a80ed5fc359e..21986ea3c2071a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -64,6 +64,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -77,10 +78,8 @@ from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ( - AssetAliasEvent, ConnectionAccessor, Context, - OutletEventAccessor, OutletEventAccessors, VariableAccessor, ) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 1f453457e43235..5fb54d570943ce 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -19,8 +19,8 @@ from __future__ import annotations -import contextlib from collections.abc import ( + Callable, Container, Iterator, Mapping, @@ -51,9 +51,9 @@ AssetRef, AssetUniqueKey, AssetUriRef, - BaseAssetUniqueKey, ) from airflow.sdk.definitions.context import Context +from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK from airflow.utils.db import LazySelectSequence from airflow.utils.session import create_session from airflow.utils.types import NOTSET @@ -156,104 +156,27 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn -@attrs.define() -class AssetAliasEvent: - """ - Represeation of asset event to be triggered by an asset alias. 
- - :meta private: - """ +def _get_asset(name: str | None = None, uri: str | None = None) -> Asset: + if name: + with create_session() as session: + asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())) + elif uri: + with create_session() as session: + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())) + else: + raise ValueError("Either name or uri must be provided") - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] + return asset.to_public() -@attrs.define() -class OutletEventAccessor: - """ - Wrapper to access an outlet asset event in template. - - :meta private: - """ - - key: BaseAssetUniqueKey - extra: dict[str, Any] = attrs.Factory(dict) - asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: - """Add an AssetEvent to an existing Asset.""" - if not isinstance(self.key, AssetAliasUniqueKey): - return - - asset_alias_name = self.key.name - event = AssetAliasEvent( - source_alias_name=asset_alias_name, - dest_asset_key=AssetUniqueKey.from_asset(asset), - extra=extra or {}, - ) - self.asset_alias_events.append(event) - - -class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): +class OutletEventAccessors(OutletEventAccessorsSDK): """ Lazy mapping of outlet asset event accessors. 
:meta private: """ - _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} - - def __init__(self) -> None: - self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} - - def __str__(self) -> str: - return f"OutletEventAccessors(_dict={self._dict})" - - def __iter__(self) -> Iterator[Asset | AssetAlias]: - return ( - key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict - ) - - def __len__(self) -> int: - return len(self._dict) - - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: - hashable_key: BaseAssetUniqueKey - if isinstance(key, Asset): - hashable_key = AssetUniqueKey.from_asset(key) - elif isinstance(key, AssetAlias): - hashable_key = AssetAliasUniqueKey.from_asset_alias(key) - elif isinstance(key, AssetRef): - hashable_key = self._resolve_asset_ref(key) - else: - raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") - - if hashable_key not in self._dict: - self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) - return self._dict[hashable_key] - - def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: - with contextlib.suppress(KeyError): - return self._asset_ref_cache[ref] - - refs_to_cache: list[AssetRef] - with create_session() as session: - if isinstance(ref, AssetNameRef): - asset = session.scalar( - select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has()) - ) - refs_to_cache = [ref, AssetUriRef(asset.uri)] - elif isinstance(ref, AssetUriRef): - asset = session.scalar( - select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has()) - ) - refs_to_cache = [ref, AssetNameRef(asset.name)] - else: - raise TypeError(f"Unimplemented asset ref: {type(ref)}") - for ref in refs_to_cache: - self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset) - return unique_key + _get_asset_func: Callable[..., Asset] = _get_asset class 
LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 5ee270591481e3..5ea23de9db21ab 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -34,6 +34,7 @@ from airflow.sdk import __version__ from airflow.sdk.api.datamodels._generated import ( + AssetResponse, ConnectionResponse, DagRunType, TerminalTIState, @@ -267,6 +268,25 @@ def set( return {"ok": True} +class AssetOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse: + """Get an Asset from the API server by `name` or `uri`.""" + if name: + params = {"name": name} + elif uri: + params = {"uri": uri} + else: + raise ValueError("Either `name` or `uri` must be provided") + + resp = self.client.get("assets/", params=params) + return AssetResponse.model_validate_json(resp.read()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -374,6 +394,12 @@ def xcoms(self) -> XComOperations: """Operations related to XComs.""" return XComOperations(self) + @lru_cache() # type: ignore[misc] + @property + def assets(self) -> AssetOperations: + """Operations related to Assets.""" + return AssetOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index a8b478d07f029a..6e5a07af179e29 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -29,6 +29,16 @@ from pydantic import BaseModel, ConfigDict, Field +class AssetAliasResponse(BaseModel): + """ + Asset alias schema for responses.
+ """ + + id: Annotated[int, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + group: Annotated[str, Field(title="Group")] + + class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -55,6 +65,16 @@ class DagRunType(str, Enum): ASSET_TRIGGERED = "asset_triggered" +class DagScheduleAssetReference(BaseModel): + """ + DAG schedule reference schema for assets. + """ + + dag_id: Annotated[str, Field(title="Dag Id")] + created_at: Annotated[datetime, Field(title="Created At")] + updated_at: Annotated[datetime, Field(title="Updated At")] + + class IntermediateTIState(str, Enum): """ States that a Task Instance can be in that indicate it is not yet in a terminal or running state. @@ -120,6 +140,17 @@ class TITargetStatePayload(BaseModel): state: IntermediateTIState +class TaskOutletAssetReference(BaseModel): + """ + Task outlet reference schema for assets. + """ + + dag_id: Annotated[str, Field(title="Dag Id")] + task_id: Annotated[str, Field(title="Task Id")] + created_at: Annotated[datetime, Field(title="Created At")] + updated_at: Annotated[datetime, Field(title="Updated At")] + + class TerminalTIState(str, Enum): """ States that a Task Instance can be in that indicate it has reached a terminal state. @@ -187,6 +218,23 @@ class TaskInstance(BaseModel): hostname: Annotated[str | None, Field(title="Hostname")] = None +class AssetResponse(BaseModel): + """ + Asset schema for responses. 
+ """ + + id: Annotated[int, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + uri: Annotated[str, Field(title="Uri")] + group: Annotated[str, Field(title="Group")] + extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + created_at: Annotated[datetime, Field(title="Created At")] + updated_at: Annotated[datetime, Field(title="Updated At")] + consuming_dags: Annotated[list[DagScheduleAssetReference], Field(title="Consuming Dags")] + producing_tasks: Annotated[list[TaskOutletAssetReference], Field(title="Producing Tasks")] + aliases: Annotated[list[AssetAliasResponse], Field(title="Aliases")] + + class DagRun(BaseModel): """ Schema for DagRun model with minimal required fields needed for Runtime. diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index b6874d47f090cd..f8aaab65af4f13 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -50,6 +50,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue from airflow.sdk.api.datamodels._generated import ( + AssetResponse, BundleInfo, ConnectionResponse, TaskInstance, @@ -79,6 +80,25 @@ class StartupDetails(BaseModel): type: Literal["StartupDetails"] = "StartupDetails" +class AssetResult(AssetResponse): + """Response to a GetAssetByName or GetAssetByUri request.""" + + type: Literal["AssetResult"] = "AssetResult" + + @classmethod + def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult: + """ + Get AssetResult from AssetResponse. + + AssetResponse is autogenerated from the API schema, so we need to convert it to AssetResult + for communication between the Supervisor and the task process. + """ + # Exclude defaults to avoid sending unnecessary data + # Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True + # to avoid sending unset fields (which are defaults in our case).
+ return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult") + + class XComResult(XComResponse): """Response to ReadXCom request.""" @@ -133,7 +153,7 @@ class ErrorResponse(BaseModel): ToTask = Annotated[ - Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse], + Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse, AssetResult], Field(discriminator="type"), ] @@ -231,12 +251,24 @@ class SetRenderedFields(BaseModel): type: Literal["SetRenderedFields"] = "SetRenderedFields" +class GetAssetByName(BaseModel): + name: str + type: Literal["GetAssetByName"] = "GetAssetByName" + + +class GetAssetByUri(BaseModel): + uri: str + type: Literal["GetAssetByUri"] = "GetAssetByUri" + + ToSupervisor = Annotated[ Union[ TaskState, GetXCom, GetConnection, GetVariable, + GetAssetByName, + GetAssetByUri, DeferTask, PutVariable, SetXCom, diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index cdb3880bb36b33..08b68300b1b529 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -17,20 +17,31 @@ from __future__ import annotations import contextlib -from collections.abc import Generator -from typing import TYPE_CHECKING, Any +from collections.abc import Callable, Generator, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Union +import attrs import structlog from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT from airflow.sdk.definitions._internal.types import NOTSET +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasUniqueKey, + AssetNameRef, + AssetRef, + AssetUniqueKey, + AssetUriRef, + BaseAssetUniqueKey, +) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType if TYPE_CHECKING: from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from 
airflow.sdk.definitions.variable import Variable - from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult + from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, VariableResult log = structlog.get_logger(logger_name="task") @@ -163,6 +174,115 @@ def __eq__(self, other: object) -> bool: return True +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] + + +@attrs.define +class OutletEventAccessor: + """Wrapper to access an outlet asset event in template.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] = attrs.Factory(dict) + asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: + """Add an AssetEvent to an existing Asset.""" + if not isinstance(self.key, AssetAliasUniqueKey): + return + + asset_alias_name = self.key.name + event = AssetAliasEvent( + source_alias_name=asset_alias_name, + dest_asset_key=AssetUniqueKey.from_asset(asset), + extra=extra or {}, + ) + self.asset_alias_events.append(event) + + +def _get_asset(name: str | None = None, uri: str | None = None) -> Asset: + from airflow.sdk.definitions.asset import Asset + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + if name: + msg = GetAssetByName(name=name) + elif uri: + msg = GetAssetByUri(uri=uri) + else: + raise ValueError("Either name or uri must be provided") + + SUPERVISOR_COMMS.send_request(log=log, msg=msg) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, AssetResult) + return Asset(**msg.model_dump(exclude={"type"})) + + +class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], 
OutletEventAccessor]): + """Lazy mapping of outlet asset event accessors.""" + + _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} + + # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py + _get_asset_func: Callable[..., Asset] = _get_asset + + def __init__(self) -> None: + self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} + + def __str__(self) -> str: + return f"OutletEventAccessors(_dict={self._dict})" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: + return ( + key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict + ) + + def __len__(self) -> int: + return len(self._dict) + + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: + hashable_key: BaseAssetUniqueKey + if isinstance(key, Asset): + hashable_key = AssetUniqueKey.from_asset(key) + elif isinstance(key, AssetAlias): + hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + elif isinstance(key, AssetRef): + hashable_key = self._resolve_asset_ref(key) + else: + raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") + + if hashable_key not in self._dict: + self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + return self._dict[hashable_key] + + def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: + with contextlib.suppress(KeyError): + return self._asset_ref_cache[ref] + + refs_to_cache: list[AssetRef] + if isinstance(ref, AssetNameRef): + asset = self._get_asset_func(name=ref.name) + refs_to_cache = [ref, AssetUriRef(asset.uri)] + elif isinstance(ref, AssetUriRef): + asset = self._get_asset_func(uri=ref.uri) + refs_to_cache = [ref, AssetNameRef(asset.name)] + else: + raise TypeError(f"Unimplemented asset ref: {type(ref)}") + unique_key = AssetUniqueKey.from_asset(asset) + for ref in refs_to_cache: + self._asset_ref_cache[ref] = unique_key + return unique_key + + @contextlib.contextmanager def 
set_current_context(context: Context) -> Generator[Context, None, None]: """ diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 32895d36524d84..037a62efcc025d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -61,8 +61,11 @@ VariableResponse, ) from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, + GetAssetByUri, GetConnection, GetVariable, GetXCom, @@ -787,6 +790,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, SetRenderedFields): self.client.task_instances.set_rtif(self.id, msg.rendered_fields) + elif isinstance(msg, GetAssetByName): + asset_resp = self.client.assets.get(name=msg.name) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() + elif isinstance(msg, GetAssetByUri): + asset_resp = self.client.assets.get(uri=msg.uri) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() else: log.error("Unhandled request", msg=msg) return diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 186faac878a0a7..d252c24be180c0 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -50,6 +50,7 @@ from airflow.sdk.execution_time.context import ( ConnectionAccessor, MacrosAccessor, + OutletEventAccessors, VariableAccessor, set_current_context, ) @@ -92,12 +93,13 @@ def get_template_context(self) -> Context: # TODO: Ensure that ti.log_url and such are available to use in context # especially after removal of `conf` from Context.
"ti": self, - # "outlet_events": OutletEventAccessors(), + "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, "expanded_ti_count": None, # TODO: Implement this # "inlet_events": InletEventsAccessors(task.inlets, session=session), "macros": MacrosAccessor(), # "params": validated_params, + # TODO: Make this go through Public API longer term. # "prev_data_interval_start_success": get_prev_data_interval_start_success(), # "prev_data_interval_end_success": get_prev_data_interval_end_success(), # "prev_start_date_success": get_prev_start_date_success(), diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 6527d517e375f4..f1e82fdfdf1681 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,12 +22,16 @@ import pytest from airflow.sdk import get_current_context +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( + AssetAliasEvent, ConnectionAccessor, + OutletEventAccessor, + OutletEventAccessors, VariableAccessor, _convert_connection_result_conn, _convert_variable_result_to_variable, @@ -248,3 +252,74 @@ def test_nested_context(self): assert ctx["ContextId"] == i # End of with statement ctx_list[i].__exit__(None, None, None) + + +class TestOutletEventAccessor: + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), + extra={}, + ) + 
], + ), + ), + ) + def test_add(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset("test_uri") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={}) + outlet_event_accessor.add(asset) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), + extra={}, + ) + ], + ), + ), + ) + def test_add_with_db(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset(uri="test://asset-uri", name="test-asset") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) + outlet_event_accessor.add(asset, extra={}) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + +class TestOutletEventAccessors: + @pytest.mark.parametrize( + "access_key, internal_key", + ( + (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), + ( + Asset(name="test", uri="test://asset"), + AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), + ), + (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), + ), + ) + def test___get_item__dict_key_not_exists(self, access_key, internal_key): + outlet_event_accessors = OutletEventAccessors() + assert len(outlet_event_accessors) == 0 + outlet_event_accessor = outlet_event_accessors[access_key] + assert len(outlet_event_accessors) == 1 + assert outlet_event_accessor.key == internal_key + assert outlet_event_accessor.extra == {} diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 12c3455ccfe1e6..ba39dc4a1489e0 100644 --- 
a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -41,8 +41,10 @@ from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, GetConnection, GetVariable, GetXCom, @@ -805,13 +807,14 @@ def watched_subprocess(self, mocker): ) @pytest.mark.parametrize( - ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], + ["message", "expected_buffer", "client_attr_path", "method_arg", "method_kwarg", "mock_response"], [ pytest.param( GetConnection(conn_id="test_conn"), b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', "connections.get", ("test_conn",), + {}, ConnectionResult(conn_id="test_conn", conn_type="mysql"), id="get_connection", ), @@ -820,6 +823,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', "variables.get", ("test_key",), + {}, VariableResult(key="test_key", value="test_value"), id="get_variable", ), @@ -828,6 +832,7 @@ def watched_subprocess(self, mocker): b"", "variables.set", ("test_key", "test_value", "test_description"), + {}, {"ok": True}, id="set_variable", ), @@ -836,6 +841,7 @@ def watched_subprocess(self, mocker): b"", "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + {}, "", id="patch_task_instance_to_deferred", ), @@ -853,6 +859,7 @@ def watched_subprocess(self, mocker): end_date=timezone.parse("2024-10-31T12:00:00Z"), ), ), + {}, "", id="patch_task_instance_to_up_for_reschedule", ), @@ -861,6 +868,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom", ), @@ 
-871,6 +879,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom_map_index", ), @@ -879,6 +888,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":null,"type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value=None, type="XComResult"), id="get_xcom_not_found", ), @@ -900,6 +910,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', None, ), + {}, {"ok": True}, id="set_xcom", ), @@ -922,6 +933,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', 2, ), + {}, {"ok": True}, id="set_xcom_with_map_index", ), @@ -932,6 +944,7 @@ def watched_subprocess(self, mocker): b"", "", (), + {}, "", id="patch_task_instance_to_skipped", ), @@ -940,9 +953,44 @@ def watched_subprocess(self, mocker): b"", "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), + {}, {"ok": True}, id="set_rtif", ), + pytest.param( + GetAssetByName(name="test_asset"), + AssetResult( + id=1, + name="test_asset", + uri="s3://test_bucket/test_asset", + group="asset", + extra={"foo": "bar"}, + created_at=timezone.parse("2021-01-01T00:00:00"), + updated_at=timezone.parse("2021-01-01T00:00:00"), + consuming_dags=[], + producing_tasks=[], + aliases=[], + ) + .model_dump_json() + .encode() + + b"\n", + "assets.get", + [], + {"name": "test_asset"}, + AssetResult( + id=1, + name="test_asset", + uri="s3://test_bucket/test_asset", + group="asset", + extra={"foo": "bar"}, + created_at=timezone.parse("2021-01-01T00:00:00"), + updated_at=timezone.parse("2021-01-01T00:00:00"), + consuming_dags=[], + producing_tasks=[], + aliases=[], + ), + id="get_asset_by_name", + ), ], ) def test_handle_requests( @@ -953,8 +1001,8 @@ 
def test_handle_requests( expected_buffer, client_attr_path, method_arg, + method_kwarg, mock_response, - time_machine, ): """ Test handling of different messages to the subprocess. For any new message type, add a @@ -980,7 +1028,7 @@ def test_handle_requests( # Verify the correct client method was called if client_attr_path: - mock_client_method.assert_called_once_with(*method_arg) + mock_client_method.assert_called_once_with(*method_arg, **method_kwarg) # Verify the response was added to the buffer val = watched_subprocess.stdin.getvalue() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index f716317ad24fcf..0b7d762e3404c6 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -52,7 +52,12 @@ VariableResult, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ( + ConnectionAccessor, + MacrosAccessor, + OutletEventAccessors, + VariableAccessor, +) from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -613,6 +618,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, @@ -645,6 +651,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti): "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, diff --git a/tests/api_fastapi/execution_api/routes/test_assets.py b/tests/api_fastapi/execution_api/routes/test_assets.py new file mode 100644 index 00000000000000..000cf12973daf0 --- /dev/null +++ 
b/tests/api_fastapi/execution_api/routes/test_assets.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.asset import AssetActive, AssetModel
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.parse("2021-01-01T00:00:00")
+
+pytestmark = pytest.mark.db_test
+
+
+class TestGetAsset:
+    """Exercise ``GET /execution/assets/`` with the ``name`` and ``uri`` query parameters."""
+
+    def test_get_asset_by_name(self, client, session):
+        """A stored, active asset is returned when queried by ``name``."""
+        asset = AssetModel(
+            id=1,
+            name="test_get_asset_by_name",
+            uri="s3://bucket/key",
+            group="asset",
+            extra={"foo": "bar"},
+            created_at=DEFAULT_DATE,
+            updated_at=DEFAULT_DATE,
+        )
+
+        asset_active = AssetActive.for_asset(asset)
+
+        session.add_all([asset, asset_active])
+        session.commit()
+
+        try:
+            response = client.get("/execution/assets/", params={"name": "test_get_asset_by_name"})
+
+            assert response.status_code == 200
+            assert response.json() == {
+                "id": 1,
+                "name": "test_get_asset_by_name",
+                "uri": "s3://bucket/key",
+                "group": "asset",
+                "extra": {"foo": "bar"},
+                "created_at": DEFAULT_DATE.to_iso8601_string(),
+                "updated_at": DEFAULT_DATE.to_iso8601_string(),
+                "consuming_dags": [],
+                "producing_tasks": [],
+                "aliases": [],
+            }
+        finally:
+            # Clean up even if an assertion above fails; both lookup tests
+            # register the same URI, so leftover rows could leak between them.
+            session.delete(asset)
+            session.delete(asset_active)
+            session.commit()
+
+    def test_get_asset_by_uri(self, client, session):
+        """A stored, active asset is returned when queried by ``uri``."""
+        asset = AssetModel(
+            id=2,
+            name="test_get_asset_by_uri",
+            uri="s3://bucket/key",
+            group="asset",
+            extra={"foo": "bar"},
+            created_at=DEFAULT_DATE,
+            updated_at=DEFAULT_DATE,
+        )
+
+        asset_active = AssetActive.for_asset(asset)
+
+        session.add_all([asset, asset_active])
+        session.commit()
+
+        try:
+            response = client.get("/execution/assets/", params={"uri": "s3://bucket/key"})
+
+            assert response.status_code == 200
+            assert response.json() == {
+                "id": 2,
+                "name": "test_get_asset_by_uri",
+                "uri": "s3://bucket/key",
+                "group": "asset",
+                "extra": {"foo": "bar"},
+                "created_at": DEFAULT_DATE.to_iso8601_string(),
+                "updated_at": DEFAULT_DATE.to_iso8601_string(),
+                "consuming_dags": [],
+                "producing_tasks": [],
+                "aliases": [],
+            }
+        finally:
+            # Clean up even if an assertion above fails; both lookup tests
+            # register the same URI, so leftover rows could leak between them.
+            session.delete(asset)
+            session.delete(asset_active)
+            session.commit()
+
+    def test_asset_name_not_found(self, client):
+        """An unknown ``name`` yields a 404 with a ``not_found`` reason."""
+        response = client.get("/execution/assets/", params={"name": "non_existent"})
+
+        assert response.status_code == 404
+        assert response.json() == {
+            "detail": {
+                "message": "Asset with name non_existent not found",
+                "reason": "not_found",
+            }
+        }
+
+    def test_asset_uri_not_found(self, client):
+        """An unknown ``uri`` yields a 404 with a ``not_found`` reason."""
+        response = client.get("/execution/assets/", params={"uri": "non_existent"})
+
+        assert response.status_code == 404
+        assert response.json() == {
+            "detail": {
+                "message": "Asset with URI non_existent not found",
+                "reason": "not_found",
+            }
+        }
+
+    def test_asset_bad_request(self, client):
+        """Omitting both ``name`` and ``uri`` yields a 400 with a ``bad_request`` reason."""
+        response = client.get("/execution/assets/")
+
+        assert response.status_code == 400
+        assert response.json() == {
+            "detail": {
+                "message": "Either 'name' or 'uri' query parameter must be provided",
+                "reason": "bad_request",
+            }
+        }
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 0faeed038e648f..707595b92ffa22 100644
--- a/tests/serialization/test_serialized_objects.py
+++
b/tests/serialization/test_serialized_objects.py @@ -43,11 +43,12 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors +from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 0046ca33cc4da8..783ff63d82e172 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -22,7 +22,8 @@ from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor +from airflow.utils.context import OutletEventAccessors class TestOutletEventAccessor: