Skip to content

Commit

Permalink
AIP-72: Add support for outlet_events in Task Context
Browse files Browse the repository at this point in the history
part of #45717

This PR adds support for `outlet_events` in Context dict within the Task SDK by adding an endpoint on the API Server which is fetched when outlet_events is accessed.
  • Loading branch information
kaxil committed Jan 17, 2025
1 parent 060eeb7 commit 943ecd7
Show file tree
Hide file tree
Showing 17 changed files with 679 additions and 105 deletions.
62 changes: 62 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from datetime import datetime

from airflow.api_fastapi.core_api.base import BaseModel


class DagScheduleAssetReference(BaseModel):
    """
    DAG schedule reference schema for assets.

    Used as the element type of ``AssetResponse.consuming_dags``.
    """

    # All three fields are required: none of them declares a default.
    dag_id: str
    created_at: datetime
    updated_at: datetime


class TaskOutletAssetReference(BaseModel):
    """
    Task outlet reference schema for assets.

    Used as the element type of ``AssetResponse.producing_tasks``.
    """

    # All four fields are required: none of them declares a default.
    dag_id: str
    task_id: str
    created_at: datetime
    updated_at: datetime


class AssetResponse(BaseModel):
    """
    Asset schema for responses.

    Every field except ``extra`` is required.
    """

    id: int
    name: str
    uri: str
    group: str
    # The only optional field; it is ``None`` when not supplied.
    extra: dict | None = None
    created_at: datetime
    updated_at: datetime
    consuming_dags: list[DagScheduleAssetReference]
    producing_tasks: list[TaskOutletAssetReference]
    # ``AssetAliasResponse`` is declared further down in this module, so this
    # annotation is a forward reference at class-definition time.
    aliases: list[AssetAliasResponse]


class AssetAliasResponse(BaseModel):
    """
    Asset alias schema for responses.

    Used as the element type of ``AssetResponse.aliases``.
    """

    # All three fields are required: none of them declares a default.
    id: int
    name: str
    group: str
group: str
10 changes: 9 additions & 1 deletion airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
from __future__ import annotations

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms
from airflow.api_fastapi.execution_api.routes import (
assets,
connections,
health,
task_instances,
variables,
xcoms,
)

execution_api_router = AirflowRouter()
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
Expand Down
79 changes: 79 additions & 0 deletions airflow/api_fastapi/execution_api/routes/assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Annotated

from fastapi import HTTPException, Query, status
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.models.asset import AssetModel

# TODO: Add dependency on JWT token
# Router for the asset endpoints. The 404/401 descriptions are declared once
# here at router level; individual routes add their own entries through the
# `responses` argument of their decorator.
router = AirflowRouter(
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
    },
)


@router.get(
    "",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "description": "Either 'name' or 'uri' query parameter must be provided"
        },
        # This endpoint serves assets; the description previously said "variable".
        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the asset"},
    },
)
def get_asset(
    session: SessionDep,
    name: Annotated[str | None, Query(description="The name of the Asset")] = None,
    uri: Annotated[str | None, Query(description="The URI of the Asset")] = None,
) -> AssetResponse:
    """Get an Airflow Asset by `name` or `uri`."""
    # NOTE: FastAPI publishes the docstring above as the operation description,
    # so implementation details are kept in comments instead.
    #
    # `name` takes precedence when both parameters are supplied; an empty string
    # is treated the same as a missing parameter. A 400 is raised when neither
    # is usable and a 404 when no matching asset is found.
    if name:
        lookup = AssetModel.name == name
        not_found_msg = f"Asset with name {name} not found"
    elif uri:
        lookup = AssetModel.uri == uri
        not_found_msg = f"Asset with URI {uri} not found"
    else:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail={
                "reason": "bad_request",
                "message": "Either 'name' or 'uri' query parameter must be provided",
            },
        )

    # Both lookups share the same `AssetModel.active.has()` restriction.
    asset = session.scalar(select(AssetModel).where(lookup, AssetModel.active.has()))
    _raise_if_not_found(asset, not_found_msg)
    return AssetResponse.model_validate(asset)


def _raise_if_not_found(asset, msg):
if asset is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": msg,
},
)
3 changes: 1 addition & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
Expand All @@ -77,10 +78,8 @@
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
AssetAliasEvent,
ConnectionAccessor,
Context,
OutletEventAccessor,
OutletEventAccessors,
VariableAccessor,
)
Expand Down
105 changes: 14 additions & 91 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from __future__ import annotations

import contextlib
from collections.abc import (
Callable,
Container,
Iterator,
Mapping,
Expand Down Expand Up @@ -51,9 +51,9 @@
AssetRef,
AssetUniqueKey,
AssetUriRef,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -156,104 +156,27 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


@attrs.define()
class AssetAliasEvent:
"""
Representation of asset event to be triggered by an asset alias.
:meta private:
"""
def _get_asset(name: str | None = None, uri: str | None = None) -> Asset:
if name:
with create_session() as session:
asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()))
elif uri:
with create_session() as session:
asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()))
else:
raise ValueError("Either name or uri must be provided")

source_alias_name: str
dest_asset_key: AssetUniqueKey
extra: dict[str, Any]
return asset.to_public()


@attrs.define()
class OutletEventAccessor:
"""
Wrapper to access an outlet asset event in template.
:meta private:
"""

key: BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
"""Add an AssetEvent to an existing Asset."""
if not isinstance(self.key, AssetAliasUniqueKey):
return

asset_alias_name = self.key.name
event = AssetAliasEvent(
source_alias_name=asset_alias_name,
dest_asset_key=AssetUniqueKey.from_asset(asset),
extra=extra or {},
)
self.asset_alias_events.append(event)


class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]):
class OutletEventAccessors(OutletEventAccessorsSDK):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""

_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}

def __init__(self) -> None:
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[Asset | AssetAlias]:
return (
key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
elif isinstance(key, AssetRef):
hashable_key = self._resolve_asset_ref(key)
else:
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

if hashable_key not in self._dict:
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
return self._dict[hashable_key]

def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
with contextlib.suppress(KeyError):
return self._asset_ref_cache[ref]

refs_to_cache: list[AssetRef]
with create_session() as session:
if isinstance(ref, AssetNameRef):
asset = session.scalar(
select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has())
)
refs_to_cache = [ref, AssetUriRef(asset.uri)]
elif isinstance(ref, AssetUriRef):
asset = session.scalar(
select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has())
)
refs_to_cache = [ref, AssetNameRef(asset.name)]
else:
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
for ref in refs_to_cache:
self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset)
return unique_key
_get_asset_func: Callable[..., Asset] = _get_asset


class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
Expand Down
26 changes: 26 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
ConnectionResponse,
DagRunType,
TerminalTIState,
Expand Down Expand Up @@ -267,6 +268,25 @@ def set(
return {"ok": True}


class AssetOperations:
    """Asset-related requests, issued through the wrapped ``client``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse:
        """
        Get an Asset from the API server by ``name`` or ``uri``.

        ``name`` takes precedence when both are given; an empty string is treated
        the same as a missing argument.

        :param name: Name of the asset to look up.
        :param uri: URI of the asset to look up.
        :return: The response body parsed into an ``AssetResponse``.
        :raises ValueError: If neither ``name`` nor ``uri`` is provided.
        """
        if name:
            params = {"name": name}
        elif uri:
            params = {"uri": uri}
        else:
            raise ValueError("Either `name` or `uri` must be provided")

        # The server registers this route under the "/assets" prefix with an empty
        # sub-path, so request exactly that path instead of "assets/", which would
        # only resolve if the server redirects and the client follows it.
        resp = self.client.get("assets", params=params)
        return AssetResponse.model_validate_json(resp.read())


class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
Expand Down Expand Up @@ -374,6 +394,12 @@ def xcoms(self) -> XComOperations:
"""Operations related to XComs."""
return XComOperations(self)

@lru_cache()  # type: ignore[misc]
@property
def assets(self) -> AssetOperations:
    """Operations related to Assets."""
    return AssetOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
Loading

0 comments on commit 943ecd7

Please sign in to comment.