Skip to content

Commit

Permalink
feat(decorators/asset): move @asset to task_sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Nov 15, 2024
1 parent 60b06f1 commit 0521112
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 138 deletions.
131 changes: 0 additions & 131 deletions airflow/decorators/assets.py

This file was deleted.

4 changes: 2 additions & 2 deletions airflow/example_dags/example_asset_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import pendulum

from airflow.assets import Asset
from airflow.decorators import dag, task
from airflow.decorators.assets import asset
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.decorators import asset


@asset(uri="s3://bucket/asset1_producer", schedule=None)
Expand Down
115 changes: 115 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@

from __future__ import annotations

import inspect
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any, Callable

import attrs

from airflow.models.asset import _fetch_active_assets_by_name
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.sdk.definitions.dag import DAG, ScheduleArg
from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath


import sys
from types import FunctionType

Expand All @@ -40,3 +56,102 @@ def fixup_decorator_warning_stack(func: FunctionType):
# Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
# `warnings.warn` to ignore the decorator.
func.__globals__["warnings"] = _autostacklevel_warn()


class _AssetMainOperator(PythonOperator):
def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)
self._definition_name = definition_name
self._uri = uri

def _iter_kwargs(
self, context: Mapping[str, Any], active_assets: dict[str, Asset]
) -> Iterator[tuple[str, Any]]:
value: Any
for key in inspect.signature(self.python_callable).parameters:
if key == "self":
value = active_assets.get(self._definition_name)
elif key == "context":
value = context
else:
value = active_assets.get(key, Asset(name=key))
yield key, value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
active_assets: dict[str, Asset] = {}
asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
if "self" in inspect.signature(self.python_callable).parameters:
asset_names.append(self._definition_name)

if asset_names:
with create_session() as session:
active_assets = _fetch_active_assets_by_name(asset_names, session)
return dict(self._iter_kwargs(context, active_assets))


@attrs.define(kw_only=True)
class AssetDefinition(Asset):
"""
Asset representation from decorating a function with ``@asset``.
:meta private:
"""

function: Callable
schedule: ScheduleArg

def __attrs_post_init__(self) -> None:
parameters = inspect.signature(self.function).parameters

with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
for inlet_asset_name in parameters
if inlet_asset_name not in ("self", "context")
],
outlets=[self.to_asset()],
python_callable=self.function,
definition_name=self.name,
uri=self.uri,
)

def to_asset(self) -> Asset:
return Asset(
name=self.name,
uri=self.uri,
group=self.group,
extra=self.extra,
)

def serialize(self):
return {
"uri": self.uri,
"name": self.name,
"group": self.group,
"extra": self.extra,
}


@attrs.define(kw_only=True)
class asset:
"""Create an asset by decorating a materialization function."""

schedule: ScheduleArg
uri: str | ObjectStoragePath | None = None
group: str = ""
extra: dict[str, Any] = attrs.field(factory=dict)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")

return AssetDefinition(
name=name,
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
function=f,
schedule=self.schedule,
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import pytest

from airflow.assets import Asset
from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset
from airflow.models.asset import AssetActive, AssetModel
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition):
"uri": "s3://bucket/object",
}

@mock.patch("airflow.decorators.assets._AssetMainOperator")
@mock.patch("airflow.decorators.assets.DAG")
@mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator")
@mock.patch("airflow.sdk.definitions.decorators.DAG")
def test__attrs_post_init__(
self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
):
Expand Down Expand Up @@ -169,7 +169,10 @@ def test_determine_kwargs(self, example_asset_func_with_valid_arg_as_inlet_asset
)
assert op.determine_kwargs(context={"k": "v"}) == {
"self": Asset(
name="example_asset_func", uri="s3://bucket/object", group="MLModel", extra={"k": "v"}
name="example_asset_func",
uri="s3://bucket/object",
group="MLModel",
extra={"k": "v"},
),
"context": {"k": "v"},
"inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"),
Expand Down

0 comments on commit 0521112

Please sign in to comment.