Skip to content

Commit

Permalink
AIP-72: Add support for outlet_events in Task Context
Browse files Browse the repository at this point in the history
part of #45717

This PR adds support for `outlet_events` in Context dict within the Task SDK by adding an endpoint on the API Server which is fetched when outlet_events is accessed.
  • Loading branch information
kaxil committed Jan 17, 2025
1 parent 9984dcd commit 8c04b05
Show file tree
Hide file tree
Showing 18 changed files with 603 additions and 208 deletions.
36 changes: 36 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.api_fastapi.core_api.base import BaseModel


class AssetResponse(BaseModel):
"""Asset schema for responses with fields that are needed for Runtime."""

name: str
uri: str
group: str
extra: dict | None = None


class AssetAliasResponse(BaseModel):
"""Asset alias schema with fields that are needed for Runtime."""

name: str
group: str
10 changes: 9 additions & 1 deletion airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
from __future__ import annotations

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms
from airflow.api_fastapi.execution_api.routes import (
assets,
connections,
health,
task_instances,
variables,
xcoms,
)

execution_api_router = AirflowRouter()
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
Expand Down
71 changes: 71 additions & 0 deletions airflow/api_fastapi/execution_api/routes/assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Annotated

from fastapi import HTTPException, Query, status
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.models.asset import AssetModel

# TODO: Add dependency on JWT token
router = AirflowRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)


@router.get("/by-name")
def get_asset_by_name(
name: Annotated[str, Query(description="The name of the Asset")],
session: SessionDep,
) -> AssetResponse:
"""Get an Airflow Asset by `name`."""
asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()))
_raise_if_not_found(asset, f"Asset with name {name} not found")

return AssetResponse.model_validate(asset)


@router.get("/by-uri")
def get_asset_by_uri(
uri: Annotated[str, Query(description="The URI of the Asset")],
session: SessionDep,
) -> AssetResponse:
"""Get an Airflow Asset by `uri`."""
asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()))
_raise_if_not_found(asset, f"Asset with URI {uri} not found")

return AssetResponse.model_validate(asset)


def _raise_if_not_found(asset, msg):
if asset is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": msg,
},
)
3 changes: 1 addition & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
Expand All @@ -77,10 +78,8 @@
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
AssetAliasEvent,
ConnectionAccessor,
Context,
OutletEventAccessor,
OutletEventAccessors,
VariableAccessor,
)
Expand Down
104 changes: 14 additions & 90 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from __future__ import annotations

import contextlib
from collections.abc import (
Container,
Iterator,
Expand Down Expand Up @@ -51,9 +50,9 @@
AssetRef,
AssetUniqueKey,
AssetUriRef,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -156,104 +155,29 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


@attrs.define()
class AssetAliasEvent:
"""
Represeation of asset event to be triggered by an asset alias.
:meta private:
"""

source_alias_name: str
dest_asset_key: AssetUniqueKey
extra: dict[str, Any]


@attrs.define()
class OutletEventAccessor:
"""
Wrapper to access an outlet asset event in template.
:meta private:
"""

key: BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
"""Add an AssetEvent to an existing Asset."""
if not isinstance(self.key, AssetAliasUniqueKey):
return

asset_alias_name = self.key.name
event = AssetAliasEvent(
source_alias_name=asset_alias_name,
dest_asset_key=AssetUniqueKey.from_asset(asset),
extra=extra or {},
)
self.asset_alias_events.append(event)


class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]):
class OutletEventAccessors(OutletEventAccessorsSDK):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""

_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}

def __init__(self) -> None:
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[Asset | AssetAlias]:
return (
key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
elif isinstance(key, AssetRef):
hashable_key = self._resolve_asset_ref(key)
else:
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

if hashable_key not in self._dict:
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
return self._dict[hashable_key]

def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
with contextlib.suppress(KeyError):
return self._asset_ref_cache[ref]

refs_to_cache: list[AssetRef]
with create_session() as session:
if isinstance(ref, AssetNameRef):
@staticmethod
def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset:
if name:
with create_session() as session:
asset = session.scalar(
select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has())
select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())
)
refs_to_cache = [ref, AssetUriRef(asset.uri)]
elif isinstance(ref, AssetUriRef):
elif uri:
with create_session() as session:
asset = session.scalar(
select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has())
select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())
)
refs_to_cache = [ref, AssetNameRef(asset.name)]
else:
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
for ref in refs_to_cache:
self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset)
return unique_key
else:
raise ValueError("Either name or uri must be provided")

return asset.to_public()


class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
Expand Down
25 changes: 25 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
ConnectionResponse,
DagRunType,
TerminalTIState,
Expand Down Expand Up @@ -267,6 +268,24 @@ def set(
return {"ok": True}


class AssetOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse:
"""Get Asset value from the API server."""
if name:
resp = self.client.get("assets/by-name", params={"name": name})
elif uri:
resp = self.client.get("assets/by-uri", params={"uri": uri})
else:
raise ValueError("Either `name` or `uri` must be provided")

return AssetResponse.model_validate_json(resp.read())


class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
Expand Down Expand Up @@ -374,6 +393,12 @@ def xcoms(self) -> XComOperations:
"""Operations related to XComs."""
return XComOperations(self)

@lru_cache() # type: ignore[misc]
@property
def assets(self) -> AssetOperations:
"""Operations related to XComs."""
return AssetOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
20 changes: 20 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@
from pydantic import BaseModel, ConfigDict, Field


class AssetAliasResponse(BaseModel):
"""
Asset alias schema with fields that are needed for Runtime.
"""

name: Annotated[str, Field(title="Name")]
group: Annotated[str, Field(title="Group")]


class ConnectionResponse(BaseModel):
"""
Connection schema for responses with fields that are needed for Runtime.
Expand Down Expand Up @@ -187,6 +196,17 @@ class TaskInstance(BaseModel):
hostname: Annotated[str | None, Field(title="Hostname")] = None


class AssetResponse(BaseModel):
"""
Asset schema for responses with fields that are needed for Runtime.
"""

name: Annotated[str, Field(title="Name")]
uri: Annotated[str, Field(title="Uri")]
group: Annotated[str, Field(title="Group")]
extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None


class DagRun(BaseModel):
"""
Schema for DagRun model with minimal required fields needed for Runtime.
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,14 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat
)


@attrs.define()
@attrs.define(hash=True)
class AssetNameRef(AssetRef):
"""Name reference to an asset."""

name: str


@attrs.define()
@attrs.define(hash=True)
class AssetUriRef(AssetRef):
"""URI reference to an asset."""

Expand Down
Loading

0 comments on commit 8c04b05

Please sign in to comment.