Skip to content

Commit

Permalink
Add "@asset" to decorate a function as a DAG and an asset (#41325)
Browse files Browse the repository at this point in the history
* Implement asset definition creating a DAG

* Basic inlet dependency

* Make AssetDefinition subclass Asset

This seems to be the best way for 'schedule' dependencies to work. Still
not entirely sure; we'll revisit this.

* style: fix mypy error

* feat(asset): allow uri to be None

* fix: temporarily serialize AssetDefintion into a string

* feat(decorators/assets): rewrite how asset definition is serialized

* test(decorators/assets): add test cases to check whether asset decorator generate the right asset definition

* test(decorators/assets): add test cases to AssetDefinition

* test(decorators/asset): add test cases to Test_AssetMainOperator

* test(decorators/assets): remove unused fixtures

* docs(example_dag): add example dag for asset_decorator

* feat(decorators/assets): allow passing self and context into asset

* feat(decorators/assets): return actual asset in asset decorator

* refactor(decorators/assets): extract active assets fetching logic as _fetch_active_assets_by_name

* feat(decorators/assets): allow fethcing inlet events through AssetRef

* feat(decorators/assets): reorder import paths

* docs: update asset decorator example dag

* test: fix tests

* test(decorators/assets): extend test_determine_kwargs to cover active asset

* fix: address easy to fix comments

* fix: fix asset serialization

* refactor(decorators/assets): postpone the attribute check to AssetDefinition instead of asset decorator

* Simplify group validators

The validate_identifier validator already checks the length, so we don't
need an extra one doing that.

* style(dag): remove _wrapped_definition

* style(decorators/assets): change types.FunctionType to Callable

* refactor(decorators/assets): make session in _fetch_active_assets_by_name required

* fix(decorators/asets): remove DAG.bulk_write_to_db and remove self handling

* feat(utils/context): fetch asset_refs all at once

---------

Co-authored-by: Wei Lee <[email protected]>
  • Loading branch information
uranusjr and Lee-W authored Nov 14, 2024
1 parent 66d86f5 commit b3362f8
Show file tree
Hide file tree
Showing 13 changed files with 451 additions and 27 deletions.
34 changes: 22 additions & 12 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast, overload

import attr
import attrs
from sqlalchemy import select

from airflow.api_internal.internal_api_call import internal_api_call
Expand Down Expand Up @@ -123,6 +123,13 @@ def _validate_non_empty_identifier(instance, attribute, value):
return value


def _validate_asset_name(instance, attribute, value):
_validate_non_empty_identifier(instance, attribute, value)
if value == "self" or value == "context":
raise ValueError(f"prohibited name for asset: {value}")
return value


def extract_event_key(value: str | Asset | AssetAlias) -> str:
"""
Extract the key of an inlet or an outlet event.
Expand Down Expand Up @@ -158,6 +165,13 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SE
return []


@attrs.define(kw_only=True)
class AssetRef:
"""Reference to an asset."""

name: str


class BaseAsset:
"""
Protocol for all asset triggers to use in ``DAG(schedule=...)``.
Expand Down Expand Up @@ -207,16 +221,12 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
raise NotImplementedError


@attr.define(unsafe_hash=False)
@attrs.define(unsafe_hash=False)
class AssetAlias(BaseAsset):
"""A represeation of asset alias which is used to create asset during the runtime."""

name: str = attr.field(validator=_validate_non_empty_identifier)
group: str = attr.field(
kw_only=True,
default="",
validator=[attr.validators.max_len(1500), _validate_identifier],
)
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier)

def iter_assets(self) -> Iterator[tuple[str, Asset]]:
return iter(())
Expand Down Expand Up @@ -258,7 +268,7 @@ def _set_extra_default(extra: dict | None) -> dict:
return extra


@attr.define(init=False, unsafe_hash=False)
@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""

Expand All @@ -267,7 +277,7 @@ class Asset(os.PathLike, BaseAsset):
group: str
extra: dict[str, Any]

asset_type: ClassVar[str] = ""
asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1

@overload
Expand Down Expand Up @@ -296,8 +306,8 @@ def __init__(
name = uri
elif uri is None:
uri = name
fields = attr.fields_dict(Asset)
self.name = _validate_non_empty_identifier(self, fields["name"], name)
fields = attrs.fields_dict(Asset)
self.name = _validate_asset_name(self, fields["name"], name)
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
Expand Down
131 changes: 131 additions & 0 deletions airflow/decorators/assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping

import attrs

from airflow.assets import Asset, AssetRef
from airflow.models.asset import _fetch_active_assets_by_name
from airflow.models.dag import DAG, ScheduleArg
from airflow.providers.standard.operators.python import PythonOperator
from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath


class _AssetMainOperator(PythonOperator):
def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)
self._definition_name = definition_name
self._uri = uri

def _iter_kwargs(
self, context: Mapping[str, Any], active_assets: dict[str, Asset]
) -> Iterator[tuple[str, Any]]:
value: Any
for key in inspect.signature(self.python_callable).parameters:
if key == "self":
value = active_assets.get(self._definition_name)
elif key == "context":
value = context
else:
value = active_assets.get(key, Asset(name=key))
yield key, value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
active_assets: dict[str, Asset] = {}
asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
if "self" in inspect.signature(self.python_callable).parameters:
asset_names.append(self._definition_name)

if asset_names:
with create_session() as session:
active_assets = _fetch_active_assets_by_name(asset_names, session)
return dict(self._iter_kwargs(context, active_assets))


@attrs.define(kw_only=True)
class AssetDefinition(Asset):
"""
Asset representation from decorating a function with ``@asset``.
:meta private:
"""

function: Callable
schedule: ScheduleArg

def __attrs_post_init__(self) -> None:
parameters = inspect.signature(self.function).parameters

with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
for inlet_asset_name in parameters
if inlet_asset_name not in ("self", "context")
],
outlets=[self.to_asset()],
python_callable=self.function,
definition_name=self.name,
uri=self.uri,
)

def to_asset(self) -> Asset:
return Asset(
name=self.name,
uri=self.uri,
group=self.group,
extra=self.extra,
)

def serialize(self):
return {
"uri": self.uri,
"name": self.name,
"group": self.group,
"extra": self.extra,
}


@attrs.define(kw_only=True)
class asset:
"""Create an asset by decorating a materialization function."""

schedule: ScheduleArg
uri: str | ObjectStoragePath | None = None
group: str = ""
extra: dict[str, Any] = attrs.field(factory=dict)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")

return AssetDefinition(
name=name,
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
function=f,
schedule=self.schedule,
)
52 changes: 52 additions & 0 deletions airflow/example_dags/example_asset_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pendulum

from airflow.assets import Asset
from airflow.decorators import dag, task
from airflow.decorators.assets import asset


@asset(uri="s3://bucket/asset1_producer", schedule=None)
def asset1_producer():
pass


@asset(uri="s3://bucket/object", schedule=None)
def asset2_producer(self, context, asset1_producer):
print(self)
print(context["inlet_events"][asset1_producer])


@dag(
schedule=Asset(uri="s3://bucket/asset1_producer", name="asset1_producer")
| Asset(uri="s3://bucket/object", name="asset2_producer"),
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=["consumes", "asset-scheduled"],
)
def consumes_asset_decorator():
@task(outlets=[Asset(name="process_nothing")])
def process_nothing():
pass

process_nothing()


consumes_asset_decorator()
22 changes: 22 additions & 0 deletions airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from urllib.parse import urlsplit

import sqlalchemy_jsonfield
Expand All @@ -29,6 +30,7 @@
PrimaryKeyConstraint,
String,
Table,
select,
text,
)
from sqlalchemy.orm import relationship
Expand All @@ -39,6 +41,26 @@
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from typing import Sequence

from sqlalchemy.orm import Session


def _fetch_active_assets_by_name(
names: Sequence[str],
session: Session,
) -> dict[str, Asset]:
return {
asset_model[0].name: asset_model[0].to_public()
for asset_model in session.execute(
select(AssetModel)
.join(AssetActive, AssetActive.name == AssetModel.name)
.where(AssetActive.name.in_(name for name in names))
)
}


alias_association_table = Table(
"asset_alias_asset",
Base.metadata,
Expand Down
4 changes: 1 addition & 3 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,9 +777,7 @@ def get_is_paused(self, session=NEW_SESSION) -> None:
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
return TaskSDKDag.get_serialized_fields() | {
"_processor_dags_folder",
}
return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}

@staticmethod
@internal_api_call
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class DagAttributeTypes(str, Enum):
ASSET_ALIAS = "asset_alias"
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
ASSET_REF = "asset_ref"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
Expand Down
17 changes: 16 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@
{ "type": "integer" }
]
},
"asset_definition": {
"type": "object",
"properties": {
"uri": { "type": "string" },
"name": { "type": "string" },
"group": { "type": "string" },
"extra": {
"anyOf": [
{"type": "null"},
{ "$ref": "#/definitions/dict" }
]
}
},
"required": [ "uri", "extra" ]
},
"asset": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -153,7 +168,7 @@
"_processor_dags_folder": {
"anyOf": [
{ "type": "null" },
{"type": "string"}
{ "type": "string" }
]
},
"dag_display_name": { "type" : "string"},
Expand Down
Loading

0 comments on commit b3362f8

Please sign in to comment.