diff --git a/airflow/api_fastapi/execution_api/datamodels/asset.py b/airflow/api_fastapi/execution_api/datamodels/asset.py new file mode 100644 index 0000000000000..6d3a53c3e4ca8 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/asset.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import BaseModel + + +class AssetResponse(BaseModel): + """Asset schema for responses with fields that are needed for Runtime.""" + + name: str + uri: str + group: str + extra: dict | None = None + + +class AssetAliasResponse(BaseModel): + """Asset alias schema with fields that are needed for Runtime.""" + + name: str + group: str diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow/api_fastapi/execution_api/routes/__init__.py index 0383503f18b87..793cd8fe08494 100644 --- a/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,9 +17,17 @@ from __future__ import annotations from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms +from airflow.api_fastapi.execution_api.routes import ( + assets, + connections, + health, + task_instances, + variables, + xcoms, +) execution_api_router = AirflowRouter() +execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) execution_api_router.include_router(health.router, tags=["Health"]) execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) diff --git a/airflow/api_fastapi/execution_api/routes/assets.py b/airflow/api_fastapi/execution_api/routes/assets.py new file mode 100644 index 0000000000000..213c599befb3e --- /dev/null +++ b/airflow/api_fastapi/execution_api/routes/assets.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Annotated + +from fastapi import HTTPException, Query, status +from sqlalchemy import select + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse +from airflow.models.asset import AssetModel + +# TODO: Add dependency on JWT token +router = AirflowRouter( + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) + + +@router.get("/by-name") +def get_asset_by_name( + name: Annotated[str, Query(description="The name of the Asset")], + session: SessionDep, +) -> AssetResponse: + """Get an Airflow Asset by `name`.""" + asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with name {name} not found") + + return AssetResponse.model_validate(asset) + + +@router.get("/by-uri") +def get_asset_by_uri( + uri: Annotated[str, Query(description="The URI of the Asset")], + session: SessionDep, +) -> AssetResponse: + """Get an Airflow Asset by `uri`.""" + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with URI {uri} not found") + + return AssetResponse.model_validate(asset) + + +def _raise_if_not_found(asset, msg): + if asset is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": msg, + }, + ) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 11c293b531fa6..d828a9a5b6b24 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -64,6 +64,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -77,10 +78,8 @@ from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ( - AssetAliasEvent, ConnectionAccessor, Context, - OutletEventAccessor, OutletEventAccessors, VariableAccessor, ) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 1f453457e4323..168243290fabc 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -19,7 +19,6 @@ from __future__ import annotations -import contextlib from collections.abc import ( Container, Iterator, @@ -51,9 +50,9 @@ AssetRef, AssetUniqueKey, AssetUriRef, - BaseAssetUniqueKey, ) from airflow.sdk.definitions.context import Context +from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK from airflow.utils.db import LazySelectSequence from airflow.utils.session import create_session from airflow.utils.types import NOTSET @@ -156,104 +155,29 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn -@attrs.define() -class AssetAliasEvent: - """ - Represeation of asset event to be triggered by an asset alias. - - :meta private: - """ - - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] - - -@attrs.define() -class OutletEventAccessor: - """ - Wrapper to access an outlet asset event in template. - - :meta private: - """ - - key: BaseAssetUniqueKey - extra: dict[str, Any] = attrs.Factory(dict) - asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: - """Add an AssetEvent to an existing Asset.""" - if not isinstance(self.key, AssetAliasUniqueKey): - return - - asset_alias_name = self.key.name - event = AssetAliasEvent( - source_alias_name=asset_alias_name, - dest_asset_key=AssetUniqueKey.from_asset(asset), - extra=extra or {}, - ) - self.asset_alias_events.append(event) - - -class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): +class OutletEventAccessors(OutletEventAccessorsSDK): """ Lazy mapping of outlet asset event accessors. :meta private: """ - _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} - - def __init__(self) -> None: - self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} - - def __str__(self) -> str: - return f"OutletEventAccessors(_dict={self._dict})" - - def __iter__(self) -> Iterator[Asset | AssetAlias]: - return ( - key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict - ) - - def __len__(self) -> int: - return len(self._dict) - - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: - hashable_key: BaseAssetUniqueKey - if isinstance(key, Asset): - hashable_key = AssetUniqueKey.from_asset(key) - elif isinstance(key, AssetAlias): - hashable_key = AssetAliasUniqueKey.from_asset_alias(key) - elif isinstance(key, AssetRef): - hashable_key = self._resolve_asset_ref(key) - else: - raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") - - if hashable_key not in self._dict: - self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) - return self._dict[hashable_key] - - def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: - with contextlib.suppress(KeyError): - return self._asset_ref_cache[ref] - - refs_to_cache: list[AssetRef] - with create_session() as session: - if isinstance(ref, AssetNameRef): + @staticmethod + def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: + if name: + with create_session() as session: asset = session.scalar( - select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has()) + select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()) ) - refs_to_cache = [ref, AssetUriRef(asset.uri)] - elif isinstance(ref, AssetUriRef): + elif uri: + with create_session() as session: asset = session.scalar( - select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has()) + select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()) ) - refs_to_cache = [ref, AssetNameRef(asset.name)] - else: - raise TypeError(f"Unimplemented asset ref: {type(ref)}") - for ref in refs_to_cache: - self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset) - return unique_key + else: + raise ValueError("Either name or uri must be provided") + + return asset.to_public() class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 5ee270591481e..e73e5aebea64b 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -34,6 +34,7 @@ from airflow.sdk import __version__ from airflow.sdk.api.datamodels._generated import ( + AssetResponse, ConnectionResponse, DagRunType, TerminalTIState, @@ -267,6 +268,24 @@ def set( return {"ok": True} +class AssetOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse: + """Get Asset value from the API server.""" + if name: + resp = self.client.get("assets/by-name", params={"name": name}) + elif uri: + resp = self.client.get("assets/by-uri", params={"uri": uri}) + else: + raise ValueError("Either `name` or `uri` must be provided") + + return AssetResponse.model_validate_json(resp.read()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -374,6 +393,12 @@ def xcoms(self) -> XComOperations: """Operations related to XComs.""" return XComOperations(self) + @lru_cache() # type: ignore[misc] + @property + def assets(self) -> AssetOperations: + """Operations related to XComs.""" + return AssetOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index a8b478d07f029..f0a04da21c894 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -29,6 +29,15 @@ from pydantic import BaseModel, ConfigDict, Field +class AssetAliasResponse(BaseModel): + """ + Asset alias schema with fields that are needed for Runtime. + """ + + name: Annotated[str, Field(title="Name")] + group: Annotated[str, Field(title="Group")] + + class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -187,6 +196,17 @@ class TaskInstance(BaseModel): hostname: Annotated[str | None, Field(title="Hostname")] = None +class AssetResponse(BaseModel): + """ + Asset schema for responses with fields that are needed for Runtime. + """ + + name: Annotated[str, Field(title="Name")] + uri: Annotated[str, Field(title="Uri")] + group: Annotated[str, Field(title="Group")] + extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + + class DagRun(BaseModel): """ Schema for DagRun model with minimal required fields needed for Runtime. diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 5b0cbb4a784d9..ea89f1b681701 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -488,14 +488,14 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat ) -@attrs.define() +@attrs.define(hash=True) class AssetNameRef(AssetRef): """Name reference to an asset.""" name: str -@attrs.define() +@attrs.define(hash=True) class AssetUriRef(AssetRef): """URI reference to an asset.""" diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index b6874d47f090c..f8aaab65af4f1 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -50,6 +50,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue from airflow.sdk.api.datamodels._generated import ( + AssetResponse, BundleInfo, ConnectionResponse, TaskInstance, @@ -79,6 +80,25 @@ class StartupDetails(BaseModel): type: Literal["StartupDetails"] = "StartupDetails" +class AssetResult(AssetResponse): + """Response to ReadXCom request.""" + + type: Literal["AssetResult"] = "AssetResult" + + @classmethod + def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult: + """ + Get AssetResult from AssetResponse. + + AssetResponse is autogenerated from the API schema, so we need to convert it to AssetResult + for communication between the Supervisor and the task process. + """ + # Exclude defaults to avoid sending unnecessary data + # Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True + # to avoid sending unset fields (which are defaults in our case). + return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult") + + class XComResult(XComResponse): """Response to ReadXCom request.""" @@ -133,7 +153,7 @@ class ErrorResponse(BaseModel): ToTask = Annotated[ - Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse], + Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse, AssetResult], Field(discriminator="type"), ] @@ -231,12 +251,24 @@ class SetRenderedFields(BaseModel): type: Literal["SetRenderedFields"] = "SetRenderedFields" +class GetAssetByName(BaseModel): + name: str + type: Literal["GetAssetByName"] = "GetAssetByName" + + +class GetAssetByUri(BaseModel): + uri: str + type: Literal["GetAssetByUri"] = "GetAssetByUri" + + ToSupervisor = Annotated[ Union[ TaskState, GetXCom, GetConnection, GetVariable, + GetAssetByName, + GetAssetByUri, DeferTask, PutVariable, SetXCom, diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index cdb3880bb36b3..918526c3004c2 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -17,20 +17,31 @@ from __future__ import annotations import contextlib -from collections.abc import Generator -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Union +import attrs import structlog from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT from airflow.sdk.definitions._internal.types import NOTSET +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasUniqueKey, + AssetNameRef, + AssetRef, + AssetUniqueKey, + AssetUriRef, + BaseAssetUniqueKey, +) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType if TYPE_CHECKING: from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.variable import Variable - from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult + from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, VariableResult log = structlog.get_logger(logger_name="task") @@ -163,6 +174,112 @@ def __eq__(self, other: object) -> bool: return True +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] + + +@attrs.define +class OutletEventAccessor: + """Wrapper to access an outlet asset event in template.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] = attrs.Factory(dict) + asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: + """Add an AssetEvent to an existing Asset.""" + if not isinstance(self.key, AssetAliasUniqueKey): + return + + asset_alias_name = self.key.name + event = AssetAliasEvent( + source_alias_name=asset_alias_name, + dest_asset_key=AssetUniqueKey.from_asset(asset), + extra=extra or {}, + ) + self.asset_alias_events.append(event) + + +class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): + """Lazy mapping of outlet asset event accessors.""" + + _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} + + def __init__(self) -> None: + self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} + + def __str__(self) -> str: + return f"OutletEventAccessors(_dict={self._dict})" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: + return ( + key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict + ) + + def __len__(self) -> int: + return len(self._dict) + + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: + hashable_key: BaseAssetUniqueKey + if isinstance(key, Asset): + hashable_key = AssetUniqueKey.from_asset(key) + elif isinstance(key, AssetAlias): + hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + elif isinstance(key, AssetRef): + hashable_key = self._resolve_asset_ref(key) + else: + raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") + + if hashable_key not in self._dict: + self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + return self._dict[hashable_key] + + def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: + with contextlib.suppress(KeyError): + return self._asset_ref_cache[ref] + + refs_to_cache: list[AssetRef] + if isinstance(ref, AssetNameRef): + asset = self._get_asset_from_db(name=ref.name) + refs_to_cache = [ref, AssetUriRef(asset.uri)] + elif isinstance(ref, AssetUriRef): + asset = self._get_asset_from_db(uri=ref.uri) + refs_to_cache = [ref, AssetNameRef(asset.name)] + else: + raise TypeError(f"Unimplemented asset ref: {type(ref)}") + unique_key = AssetUniqueKey.from_asset(asset) + for ref in refs_to_cache: + self._asset_ref_cache[ref] = unique_key + return unique_key + + # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py + @staticmethod + def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: + from airflow.sdk.definitions.asset import Asset + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + if name: + SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + elif uri: + SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + else: + raise ValueError("Either name or uri must be provided") + + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, AssetResult) + return Asset(**msg.model_dump(exclude={"type"})) + + @contextlib.contextmanager def set_current_context(context: Context) -> Generator[Context, None, None]: """ diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 32895d36524d8..bd50ee5126b94 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -61,8 +61,11 @@ VariableResponse, ) from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, + GetAssetByUri, GetConnection, GetVariable, GetXCom, @@ -787,6 +790,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, SetRenderedFields): self.client.task_instances.set_rtif(self.id, msg.rendered_fields) + elif isinstance(msg, GetAssetByName): + asset_resp = self.client.assets.get(name=msg.name) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() + elif isinstance(msg, GetAssetByUri): + asset_resp = self.client.assets.get(uri=msg.uri) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() else: log.error("Unhandled request", msg=msg) return diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 186faac878a0a..d252c24be180c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -50,6 +50,7 @@ from airflow.sdk.execution_time.context import ( ConnectionAccessor, MacrosAccessor, + OutletEventAccessors, VariableAccessor, set_current_context, ) @@ -92,12 +93,13 @@ def get_template_context(self) -> Context: # TODO: Ensure that ti.log_url and such are available to use in context # especially after removal of `conf` from Context. "ti": self, - # "outlet_events": OutletEventAccessors(), + "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, "expanded_ti_count": None, # TODO: Implement this # "inlet_events": InletEventsAccessors(task.inlets, session=session), "macros": MacrosAccessor(), # "params": validated_params, + # TODO: Make this go through Public API longer term. # "prev_data_interval_start_success": get_prev_data_interval_start_success(), # "prev_data_interval_end_success": get_prev_data_interval_end_success(), # "prev_start_date_success": get_prev_start_date_success(), diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 6527d517e375f..e3ef15dc934cf 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,12 +22,16 @@ import pytest from airflow.sdk import get_current_context +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType -from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult +from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( + AssetAliasEvent, ConnectionAccessor, + OutletEventAccessor, + OutletEventAccessors, VariableAccessor, _convert_connection_result_conn, _convert_variable_result_to_variable, @@ -248,3 +252,100 @@ def test_nested_context(self): assert ctx["ContextId"] == i # End of with statement ctx_list[i].__exit__(None, None, None) + + +class TestOutletEventAccessor: + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), + extra={}, + ) + ], + ), + ), + ) + def test_add(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset("test_uri") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={}) + outlet_event_accessor.add(asset) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), + extra={}, + ) + ], + ), + ), + ) + def test_add_with_db(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset(uri="test://asset-uri", name="test-asset") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) + outlet_event_accessor.add(asset, extra={}) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + +class TestOutletEventAccessors: + @pytest.mark.parametrize( + "access_key, internal_key", + ( + (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), + ( + Asset(name="test", uri="test://asset"), + AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), + ), + (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), + ), + ) + def test__get_item__dict_key_not_exists(self, access_key, internal_key): + outlet_event_accessors = OutletEventAccessors() + assert len(outlet_event_accessors) == 0 + outlet_event_accessor = outlet_event_accessors[access_key] + assert len(outlet_event_accessors) == 1 + assert outlet_event_accessor.key == internal_key + assert outlet_event_accessor.extra == {} + + @pytest.mark.parametrize( + ["access_key", "asset"], + ( + (Asset.ref(name="test"), Asset(name="test")), + (Asset.ref(name="test1"), Asset(name="test1", uri="test://asset-uri")), + (Asset.ref(uri="test://asset-uri"), Asset(uri="test://asset-uri")), + ), + ) + def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms): + """Test accessing OutletEventAccessors with AssetRef resolves to correct Asset.""" + internal_key = AssetUniqueKey.from_asset(asset) + outlet_event_accessors = OutletEventAccessors() + assert len(outlet_event_accessors) == 0 + + # Asset from the API Server via the supervisor + mock_supervisor_comms.get_message.return_value = AssetResult( + name=asset.name, + uri=asset.uri, + group=asset.group, + ) + + outlet_event_accessor = outlet_event_accessors[access_key] + assert len(outlet_event_accessors) == 1 + assert outlet_event_accessor.key == internal_key + assert outlet_event_accessor.extra == {} diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 59afa26dc2aa5..5455d0f70cdef 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -41,8 +41,11 @@ from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, + GetAssetByUri, GetConnection, GetVariable, GetXCom, @@ -805,13 +808,14 @@ def watched_subprocess(self, mocker): ) @pytest.mark.parametrize( - ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], + ["message", "expected_buffer", "client_attr_path", "method_arg", "method_kwarg", "mock_response"], [ pytest.param( GetConnection(conn_id="test_conn"), b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', "connections.get", ("test_conn",), + {}, ConnectionResult(conn_id="test_conn", conn_type="mysql"), id="get_connection", ), @@ -820,6 +824,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', "variables.get", ("test_key",), + {}, VariableResult(key="test_key", value="test_value"), id="get_variable", ), @@ -828,6 +833,7 @@ def watched_subprocess(self, mocker): b"", "variables.set", ("test_key", "test_value", "test_description"), + {}, {"ok": True}, id="set_variable", ), @@ -836,6 +842,7 @@ def watched_subprocess(self, mocker): b"", "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + {}, "", id="patch_task_instance_to_deferred", ), @@ -853,6 +860,7 @@ def watched_subprocess(self, mocker): end_date=timezone.parse("2024-10-31T12:00:00Z"), ), ), + {}, "", id="patch_task_instance_to_up_for_reschedule", ), @@ -861,6 +869,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom", ), @@ -871,6 +880,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom_map_index", ), @@ -879,6 +889,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":null,"type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value=None, type="XComResult"), id="get_xcom_not_found", ), @@ -900,6 +911,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', None, ), + {}, {"ok": True}, id="set_xcom", ), @@ -922,6 +934,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', 2, ), + {}, {"ok": True}, id="set_xcom_with_map_index", ), @@ -932,6 +945,7 @@ def watched_subprocess(self, mocker): b"", "", (), + {}, "", id="patch_task_instance_to_skipped", ), @@ -940,9 +954,28 @@ def watched_subprocess(self, mocker): b"", "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), + {}, {"ok": True}, id="set_rtif", ), + pytest.param( + GetAssetByName(name="asset"), + b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + "assets.get", + [], + {"name": "asset"}, + AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + id="get_asset_by_name", + ), + pytest.param( + GetAssetByUri(uri="s3://bucket/obj"), + b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + "assets.get", + [], + {"uri": "s3://bucket/obj"}, + AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + id="get_asset_by_uri", + ), ], ) def test_handle_requests( @@ -953,8 +986,8 @@ def test_handle_requests( expected_buffer, client_attr_path, method_arg, + method_kwarg, mock_response, - time_machine, ): """ Test handling of different messages to the subprocess. For any new message type, add a @@ -980,7 +1013,7 @@ def test_handle_requests( # Verify the correct client method was called if client_attr_path: - mock_client_method.assert_called_once_with(*method_arg) + mock_client_method.assert_called_once_with(*method_arg, **method_kwarg) # Verify the response was added to the buffer val = watched_subprocess.stdin.getvalue() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 60b39da455c69..f7734279b3ffb 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -52,7 +52,12 @@ VariableResult, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ( + ConnectionAccessor, + MacrosAccessor, + OutletEventAccessors, + VariableAccessor, +) from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -613,6 +618,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, @@ -645,6 +651,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti): "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, diff --git a/tests/api_fastapi/execution_api/routes/test_assets.py b/tests/api_fastapi/execution_api/routes/test_assets.py new file mode 100644 index 0000000000000..2cf34f8dd7bc7 --- /dev/null +++ b/tests/api_fastapi/execution_api/routes/test_assets.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.models.asset import AssetActive, AssetModel +from airflow.utils import timezone + +DEFAULT_DATE = timezone.parse("2021-01-01T00:00:00") + +pytestmark = pytest.mark.db_test + + +class TestGetAssetByName: + def test_get_asset_by_name(self, client, session): + asset = AssetModel( + id=1, + name="test_get_asset_by_name", + uri="s3://bucket/key", + group="asset", + extra={"foo": "bar"}, + created_at=DEFAULT_DATE, + updated_at=DEFAULT_DATE, + ) + + asset_active = AssetActive.for_asset(asset) + + session.add_all([asset, asset_active]) + session.commit() + + response = client.get("/execution/assets/by-name", params={"name": "test_get_asset_by_name"}) + + assert response.status_code == 200 + assert response.json() == { + "name": "test_get_asset_by_name", + "uri": "s3://bucket/key", + "group": "asset", + "extra": {"foo": "bar"}, + } + + session.delete(asset) + session.delete(asset_active) + session.commit() + + def test_asset_name_not_found(self, client): + response = client.get("/execution/assets/by-name", params={"name": "non_existent"}) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "Asset with name non_existent not found", + "reason": "not_found", + } + } + + +class TestGetAssetByUri: + def test_get_asset_by_uri(self, client, session): + asset = AssetModel( + name="test_get_asset_by_uri", + uri="s3://bucket/key", + group="asset", + extra={"foo": "bar"}, + ) + + asset_active = AssetActive.for_asset(asset) + + session.add_all([asset, asset_active]) + session.commit() + + response = client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"}) + + assert response.status_code == 200 + assert response.json() == { + "name": "test_get_asset_by_uri", + "uri": "s3://bucket/key", + "group": "asset", + "extra": {"foo": "bar"}, + } + + session.delete(asset) + session.delete(asset_active) + session.commit() + + def test_asset_uri_not_found(self, client): + response = client.get("/execution/assets/by-uri", params={"uri": "non_existent"}) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "Asset with URI non_existent not found", + "reason": "not_found", + } + } diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 0faeed038e648..707595b92ffa2 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -43,11 +43,12 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors +from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py deleted file mode 100644 index 0046ca33cc4da..0000000000000 --- a/tests/utils/test_context.py +++ /dev/null @@ -1,102 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import pytest - -from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors - - -class TestOutletEventAccessor: - @pytest.mark.parametrize( - "key, asset_alias_events", - ( - (AssetUniqueKey.from_asset(Asset("test_uri")), []), - ( - AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), - [ - AssetAliasEvent( - source_alias_name="test_alias", - dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), - extra={}, - ) - ], - ), - ), - ) - @pytest.mark.db_test - def test_add(self, key, asset_alias_events, session): - asset = Asset("test_uri") - session.add_all([AssetModel.from_public(asset), AssetActive.for_asset(asset)]) - session.flush() - - outlet_event_accessor = OutletEventAccessor(key=key, extra={}) - outlet_event_accessor.add(asset) - assert outlet_event_accessor.asset_alias_events == asset_alias_events - - @pytest.mark.db_test - @pytest.mark.parametrize( - "key, asset_alias_events", - ( - (AssetUniqueKey.from_asset(Asset("test_uri")), []), - ( - AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), - [ - AssetAliasEvent( - source_alias_name="test_alias", - dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), - extra={}, - ) - ], - ), - ), - ) - def test_add_with_db(self, key, asset_alias_events, session): - asset = Asset(uri="test://asset-uri", name="test-asset") - asm = AssetModel.from_public(asset) - aam = AssetAliasModel(name="test_alias") - session.add_all([asm, aam, AssetActive.for_asset(asset)]) - session.flush() - - outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) - outlet_event_accessor.add(asset, extra={}) - assert outlet_event_accessor.asset_alias_events == asset_alias_events - - -class TestOutletEventAccessors: - @pytest.mark.parametrize( - "access_key, internal_key", - ( - (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), - ( - Asset(name="test", uri="test://asset"), - AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), - ), - (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), - ), - ) - def test___get_item__dict_key_not_exists(self, access_key, internal_key): - outlet_event_accessors = OutletEventAccessors() - assert len(outlet_event_accessors) == 0 - outlet_event_accessor = outlet_event_accessors[access_key] - assert len(outlet_event_accessors) == 1 - assert outlet_event_accessor.key == internal_key - assert outlet_event_accessor.extra == {}