From b3362f841f31b5629f09f619cfa2d69e7199493a Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 14 Nov 2024 17:14:45 +0800 Subject: [PATCH] Add "@asset" to decorate a function as a DAG and an asset (#41325) * Implement asset definition creating a DAG * Basic inlet dependency * Make AssetDefinition subclass Asset This seems to be the best way for 'schedule' dependencies to work. Still not entirely sure; we'll revisit this. * style: fix mypy error * feat(asset): allow uri to be None * fix: temporarily serialize AssetDefinition into a string * feat(decorators/assets): rewrite how asset definition is serialized * test(decorators/assets): add test cases to check whether asset decorator generates the right asset definition * test(decorators/assets): add test cases to AssetDefinition * test(decorators/asset): add test cases to Test_AssetMainOperator * test(decorators/assets): remove unused fixtures * docs(example_dag): add example dag for asset_decorator * feat(decorators/assets): allow passing self and context into asset * feat(decorators/assets): return actual asset in asset decorator * refactor(decorators/assets): extract active assets fetching logic as _fetch_active_assets_by_name * feat(decorators/assets): allow fetching inlet events through AssetRef * feat(decorators/assets): reorder import paths * docs: update asset decorator example dag * test: fix tests * test(decorators/assets): extend test_determine_kwargs to cover active asset * fix: address easy to fix comments * fix: fix asset serialization * refactor(decorators/assets): postpone the attribute check to AssetDefinition instead of asset decorator * Simplify group validators The validate_identifier validator already checks the length, so we don't need an extra one doing that. 
* style(dag): remove _wrapped_definition * style(decorators/assets): change types.FunctionType to Callable * refactor(decorators/assets): make session in _fetch_active_assets_by_name required * fix(decorators/assets): remove DAG.bulk_write_to_db and remove self handling * feat(utils/context): fetch asset_refs all at once --------- Co-authored-by: Wei Lee --- airflow/assets/__init__.py | 34 ++-- airflow/decorators/assets.py | 131 +++++++++++++ .../example_dags/example_asset_decorator.py | 52 +++++ airflow/models/asset.py | 22 +++ airflow/models/dag.py | 4 +- airflow/serialization/enums.py | 1 + airflow/serialization/schema.json | 17 +- airflow/serialization/serialized_objects.py | 9 +- airflow/utils/context.py | 21 ++- airflow/utils/file.py | 4 +- .../core_api/routes/ui/test_assets.py | 2 +- tests/decorators/test_assets.py | 177 ++++++++++++++++++ tests/timetables/test_assets_timetable.py | 4 +- 13 files changed, 451 insertions(+), 27 deletions(-) create mode 100644 airflow/decorators/assets.py create mode 100644 airflow/example_dags/example_asset_decorator.py create mode 100644 tests/decorators/test_assets.py diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index 59e0b8668449..f1d36ac12b73 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -23,7 +23,7 @@ import warnings from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast, overload -import attr +import attrs from sqlalchemy import select from airflow.api_internal.internal_api_call import internal_api_call @@ -123,6 +123,13 @@ def _validate_non_empty_identifier(instance, attribute, value): return value +def _validate_asset_name(instance, attribute, value): + _validate_non_empty_identifier(instance, attribute, value) + if value == "self" or value == "context": + raise ValueError(f"prohibited name for asset: {value}") + return value + + def extract_event_key(value: str | Asset | AssetAlias) -> str: """ Extract the key of an inlet or an outlet 
event. @@ -158,6 +165,13 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SE return [] +@attrs.define(kw_only=True) +class AssetRef: + """Reference to an asset.""" + + name: str + + class BaseAsset: """ Protocol for all asset triggers to use in ``DAG(schedule=...)``. @@ -207,16 +221,12 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe raise NotImplementedError -@attr.define(unsafe_hash=False) +@attrs.define(unsafe_hash=False) class AssetAlias(BaseAsset): """A represeation of asset alias which is used to create asset during the runtime.""" - name: str = attr.field(validator=_validate_non_empty_identifier) - group: str = attr.field( - kw_only=True, - default="", - validator=[attr.validators.max_len(1500), _validate_identifier], - ) + name: str = attrs.field(validator=_validate_non_empty_identifier) + group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) def iter_assets(self) -> Iterator[tuple[str, Asset]]: return iter(()) @@ -258,7 +268,7 @@ def _set_extra_default(extra: dict | None) -> dict: return extra -@attr.define(init=False, unsafe_hash=False) +@attrs.define(init=False, unsafe_hash=False) class Asset(os.PathLike, BaseAsset): """A representation of data asset dependencies between workflows.""" @@ -267,7 +277,7 @@ class Asset(os.PathLike, BaseAsset): group: str extra: dict[str, Any] - asset_type: ClassVar[str] = "" + asset_type: ClassVar[str] = "asset" __version__: ClassVar[int] = 1 @overload @@ -296,8 +306,8 @@ def __init__( name = uri elif uri is None: uri = name - fields = attr.fields_dict(Asset) - self.name = _validate_non_empty_identifier(self, fields["name"], name) + fields = attrs.fields_dict(Asset) + self.name = _validate_asset_name(self, fields["name"], name) self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri)) self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type self.extra = 
_set_extra_default(extra) diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py new file mode 100644 index 000000000000..2f5052c2d5c9 --- /dev/null +++ b/airflow/decorators/assets.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping + +import attrs + +from airflow.assets import Asset, AssetRef +from airflow.models.asset import _fetch_active_assets_by_name +from airflow.models.dag import DAG, ScheduleArg +from airflow.providers.standard.operators.python import PythonOperator +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from airflow.io.path import ObjectStoragePath + + +class _AssetMainOperator(PythonOperator): + def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self._definition_name = definition_name + self._uri = uri + + def _iter_kwargs( + self, context: Mapping[str, Any], active_assets: dict[str, Asset] + ) -> Iterator[tuple[str, Any]]: + value: Any + for key in inspect.signature(self.python_callable).parameters: + if key == "self": + value = active_assets.get(self._definition_name) + elif key == "context": + value = context + else: + value = active_assets.get(key, Asset(name=key)) + yield key, value + + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: + active_assets: dict[str, Asset] = {} + asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)] + if "self" in inspect.signature(self.python_callable).parameters: + asset_names.append(self._definition_name) + + if asset_names: + with create_session() as session: + active_assets = _fetch_active_assets_by_name(asset_names, session) + return dict(self._iter_kwargs(context, active_assets)) + + +@attrs.define(kw_only=True) +class AssetDefinition(Asset): + """ + Asset representation from decorating a function with ``@asset``. 
+ + :meta private: + """ + + function: Callable + schedule: ScheduleArg + + def __attrs_post_init__(self) -> None: + parameters = inspect.signature(self.function).parameters + + with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): + _AssetMainOperator( + task_id="__main__", + inlets=[ + AssetRef(name=inlet_asset_name) + for inlet_asset_name in parameters + if inlet_asset_name not in ("self", "context") + ], + outlets=[self.to_asset()], + python_callable=self.function, + definition_name=self.name, + uri=self.uri, + ) + + def to_asset(self) -> Asset: + return Asset( + name=self.name, + uri=self.uri, + group=self.group, + extra=self.extra, + ) + + def serialize(self): + return { + "uri": self.uri, + "name": self.name, + "group": self.group, + "extra": self.extra, + } + + +@attrs.define(kw_only=True) +class asset: + """Create an asset by decorating a materialization function.""" + + schedule: ScheduleArg + uri: str | ObjectStoragePath | None = None + group: str = "" + extra: dict[str, Any] = attrs.field(factory=dict) + + def __call__(self, f: Callable) -> AssetDefinition: + if (name := f.__name__) != f.__qualname__: + raise ValueError("nested function not supported") + + return AssetDefinition( + name=name, + uri=name if self.uri is None else str(self.uri), + group=self.group, + extra=self.extra, + function=f, + schedule=self.schedule, + ) diff --git a/airflow/example_dags/example_asset_decorator.py b/airflow/example_dags/example_asset_decorator.py new file mode 100644 index 000000000000..b4de09c23146 --- /dev/null +++ b/airflow/example_dags/example_asset_decorator.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pendulum + +from airflow.assets import Asset +from airflow.decorators import dag, task +from airflow.decorators.assets import asset + + +@asset(uri="s3://bucket/asset1_producer", schedule=None) +def asset1_producer(): + pass + + +@asset(uri="s3://bucket/object", schedule=None) +def asset2_producer(self, context, asset1_producer): + print(self) + print(context["inlet_events"][asset1_producer]) + + +@dag( + schedule=Asset(uri="s3://bucket/asset1_producer", name="asset1_producer") + | Asset(uri="s3://bucket/object", name="asset2_producer"), + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["consumes", "asset-scheduled"], +) +def consumes_asset_decorator(): + @task(outlets=[Asset(name="process_nothing")]) + def process_nothing(): + pass + + process_nothing() + + +consumes_asset_decorator() diff --git a/airflow/models/asset.py b/airflow/models/asset.py index 8ade71bd0b1f..50914d51650b 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from typing import TYPE_CHECKING from urllib.parse import urlsplit import sqlalchemy_jsonfield @@ -29,6 +30,7 @@ PrimaryKeyConstraint, String, Table, + select, text, ) from sqlalchemy.orm import relationship @@ -39,6 +41,26 @@ from airflow.utils import timezone from airflow.utils.sqlalchemy import UtcDateTime +if TYPE_CHECKING: + from typing import Sequence + + from sqlalchemy.orm import Session + + +def _fetch_active_assets_by_name( + names: Sequence[str], + session: Session, +) -> dict[str, Asset]: + return { + asset_model[0].name: asset_model[0].to_public() + for asset_model in session.execute( + select(AssetModel) + .join(AssetActive, AssetActive.name == AssetModel.name) + .where(AssetActive.name.in_(name for name in names)) + ) + } + + alias_association_table = Table( "asset_alias_asset", Base.metadata, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e6a67c6ad7e5..e48ec0a9a9c5 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -777,9 +777,7 @@ def get_is_paused(self, session=NEW_SESSION) -> None: @classmethod def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" - return TaskSDKDag.get_serialized_fields() | { - "_processor_dags_folder", - } + return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"} @staticmethod @internal_api_call diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index dd63366b8a95..d1b946b38ef1 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -59,6 +59,7 @@ class DagAttributeTypes(str, Enum): ASSET_ALIAS = "asset_alias" ASSET_ANY = "asset_any" ASSET_ALL = "asset_all" + ASSET_REF = "asset_ref" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index b26b59339816..1e7232aa81a9 100644 --- a/airflow/serialization/schema.json +++ 
b/airflow/serialization/schema.json @@ -53,6 +53,21 @@ { "type": "integer" } ] }, + "asset_definition": { + "type": "object", + "properties": { + "uri": { "type": "string" }, + "name": { "type": "string" }, + "group": { "type": "string" }, + "extra": { + "anyOf": [ + {"type": "null"}, + { "$ref": "#/definitions/dict" } + ] + } + }, + "required": [ "uri", "extra" ] + }, "asset": { "type": "object", "properties": { @@ -153,7 +168,7 @@ "_processor_dags_folder": { "anyOf": [ { "type": "null" }, - {"type": "string"} + { "type": "string" } ] }, "dag_display_name": { "type" : "string"}, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 52b0bcb1530a..4b7ee6d0871b 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -40,6 +40,7 @@ AssetAlias, AssetAll, AssetAny, + AssetRef, BaseAsset, _AssetAliasCondition, ) @@ -254,7 +255,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: :meta private: """ if isinstance(var, Asset): - return {"__type": DAT.ASSET, "uri": var.uri, "extra": var.extra} + return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "extra": var.extra} if isinstance(var, AssetAlias): return {"__type": DAT.ASSET_ALIAS, "name": var.name} if isinstance(var, AssetAll): @@ -272,7 +273,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: """ dat = var["__type"] if dat == DAT.ASSET: - return Asset(var["uri"], extra=var["extra"]) + return Asset(uri=var["uri"], name=var["name"], extra=var["extra"]) if dat == DAT.ASSET_ALL: return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) if dat == DAT.ASSET_ANY: @@ -743,6 +744,8 @@ def serialize( elif isinstance(var, BaseAsset): serialized_asset = encode_asset_condition(var) return cls._encode(serialized_asset, type_=serialized_asset.pop("__type")) + elif isinstance(var, AssetRef): + return cls._encode({"name": var.name}, type_=DAT.ASSET_REF) elif isinstance(var, 
SimpleTaskInstance): return cls._encode( cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), @@ -876,6 +879,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) elif type_ == DAT.ASSET_ALL: return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) + elif type_ == DAT.ASSET_REF: + return AssetRef(name=var["name"]) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) elif type_ == DAT.CONNECTION: diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 9af1486f914e..b28559999e50 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -44,10 +44,11 @@ Asset, AssetAlias, AssetAliasEvent, + AssetRef, extract_event_key, ) from airflow.exceptions import RemovedInAirflow3Warning -from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel +from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, _fetch_active_assets_by_name from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET @@ -257,11 +258,18 @@ def __init__(self, inlets: list, *, session: Session) -> None: self._assets = {} self._asset_aliases = {} + _asset_ref_names: list[str] = [] for inlet in inlets: if isinstance(inlet, Asset): - self._assets[inlet.uri] = inlet + self._assets[inlet.name] = inlet elif isinstance(inlet, AssetAlias): self._asset_aliases[inlet.name] = inlet + elif isinstance(inlet, AssetRef): + _asset_ref_names.append(inlet.name) + + if _asset_ref_names: + for asset_name, asset in _fetch_active_assets_by_name(_asset_ref_names, self._session).items(): + self._assets[asset_name] = asset def __iter__(self) -> Iterator[str]: return iter(self._inlets) @@ -272,7 +280,7 @@ def __len__(self) -> int: def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSelectSequence: if isinstance(key, int): # Support index access; it's 
easier for trivial cases. obj = self._inlets[key] - if not isinstance(obj, (Asset, AssetAlias)): + if not isinstance(obj, (Asset, AssetAlias, AssetRef)): raise IndexError(key) else: obj = key @@ -281,10 +289,13 @@ def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSele asset_alias = self._asset_aliases[obj.name] join_clause = AssetEvent.source_aliases where_clause = AssetAliasModel.name == asset_alias.name - elif isinstance(obj, (Asset, str)): + elif isinstance(obj, (Asset, AssetRef)): + join_clause = AssetEvent.asset + where_clause = AssetModel.name == self._assets[obj.name].name + elif isinstance(obj, str): asset = self._assets[extract_event_key(obj)] join_clause = AssetEvent.asset - where_clause = AssetModel.uri == asset.uri + where_clause = AssetModel.name == asset.name else: raise ValueError(key) diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 86b7a7891ca8..962f97c8fcfa 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -328,7 +328,9 @@ def might_contain_dag_via_default_heuristic(file_path: str, zip_file: zipfile.Zi with open(file_path, "rb") as dag_file: content = dag_file.read() content = content.lower() - return all(s in content for s in (b"dag", b"airflow")) + if b"airflow" not in content: + return False + return any(s in content for s in (b"dag", b"asset")) def _find_imported_modules(module: ast.Module) -> Generator[str, None, None]: diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index b71d80ae9d31..b5c85b98ba6b 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -47,5 +47,5 @@ def test_next_run_assets(test_client, dag_maker): assert response.status_code == 200 assert response.json() == { "asset_expression": {"all": ["s3://bucket/key/1"]}, - "events": [{"id": 17, "uri": "s3://bucket/key/1", "lastUpdate": None}], + "events": [{"id": 20, "uri": 
"s3://bucket/key/1", "lastUpdate": None}], } diff --git a/tests/decorators/test_assets.py b/tests/decorators/test_assets.py new file mode 100644 index 000000000000..a3821140e548 --- /dev/null +++ b/tests/decorators/test_assets.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock +from unittest.mock import ANY + +import pytest + +from airflow.assets import Asset +from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset +from airflow.models.asset import AssetActive, AssetModel + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def example_asset_func(request): + name = "example_asset_func" + if getattr(request, "param", None) is not None: + name = request.param + + def _example_asset_func(): + return "This is example_asset" + + _example_asset_func.__name__ = name + _example_asset_func.__qualname__ = name + return _example_asset_func + + +@pytest.fixture +def example_asset_definition(example_asset_func): + return asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + example_asset_func + ) + + +@pytest.fixture +def example_asset_func_with_valid_arg_as_inlet_asset(): + def _example_asset_func(self, context, inlet_asset_1, inlet_asset_2): + return "This is example_asset" + + _example_asset_func.__name__ = "example_asset_func" + _example_asset_func.__qualname__ = "example_asset_func" + return _example_asset_func + + +class TestAssetDecorator: + def test_without_uri(self, example_asset_func): + asset_definition = asset(schedule=None)(example_asset_func) + + assert asset_definition.name == "example_asset_func" + assert asset_definition.uri == "example_asset_func" + assert asset_definition.group == "" + assert asset_definition.extra == {} + assert asset_definition.function == example_asset_func + assert asset_definition.schedule is None + + def test_with_uri(self, example_asset_func): + asset_definition = asset(schedule=None, uri="s3://bucket/object")(example_asset_func) + + assert asset_definition.name == "example_asset_func" + assert asset_definition.uri == "s3://bucket/object" + assert asset_definition.group == "" + assert asset_definition.extra == {} + assert asset_definition.function == example_asset_func + assert 
asset_definition.schedule is None + + def test_with_group_and_extra(self, example_asset_func): + asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + example_asset_func + ) + assert asset_definition.name == "example_asset_func" + assert asset_definition.uri == "s3://bucket/object" + assert asset_definition.group == "MLModel" + assert asset_definition.extra == {"k": "v"} + assert asset_definition.function == example_asset_func + assert asset_definition.schedule is None + + def test_nested_function(self): + def root_func(): + @asset(schedule=None) + def asset_func(): + pass + + with pytest.raises(ValueError) as err: + root_func() + + assert err.value.args[0] == "nested function not supported" + + @pytest.mark.parametrize("example_asset_func", ("self", "context"), indirect=True) + def test_with_invalid_asset_name(self, example_asset_func): + with pytest.raises(ValueError) as err: + asset(schedule=None)(example_asset_func) + + assert err.value.args[0].startswith("prohibited name for asset: ") + + +class TestAssetDefinition: + def test_serialzie(self, example_asset_definition): + assert example_asset_definition.serialize() == { + "extra": {"k": "v"}, + "group": "MLModel", + "name": "example_asset_func", + "uri": "s3://bucket/object", + } + + @mock.patch("airflow.decorators.assets._AssetMainOperator") + @mock.patch("airflow.decorators.assets.DAG") + def test__attrs_post_init__( + self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset + ): + asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + example_asset_func_with_valid_arg_as_inlet_asset + ) + + DAG.assert_called_once_with(dag_id="example_asset_func", schedule=None, auto_register=True) + _AssetMainOperator.assert_called_once_with( + task_id="__main__", + inlets=[ + AssetRef(name="inlet_asset_1"), + AssetRef(name="inlet_asset_2"), + ], + outlets=[asset_definition.to_asset()], + 
python_callable=ANY, + definition_name="example_asset_func", + uri="s3://bucket/object", + ) + + python_callable = _AssetMainOperator.call_args.kwargs["python_callable"] + assert python_callable == example_asset_func_with_valid_arg_as_inlet_asset + + +class Test_AssetMainOperator: + def test_determine_kwargs(self, example_asset_func_with_valid_arg_as_inlet_asset, session): + example_asset_model = AssetModel(uri="s3://bucket/object1", name="inlet_asset_1") + asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + example_asset_func_with_valid_arg_as_inlet_asset + ) + + ad_asset_model = AssetModel.from_public(asset_definition) + + session.add(example_asset_model) + session.add(ad_asset_model) + session.add(AssetActive.for_asset(example_asset_model)) + session.add(AssetActive.for_asset(ad_asset_model)) + session.commit() + + op = _AssetMainOperator( + task_id="__main__", + inlets=[AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")], + outlets=[asset_definition], + python_callable=example_asset_func_with_valid_arg_as_inlet_asset, + definition_name="example_asset_func", + ) + assert op.determine_kwargs(context={"k": "v"}) == { + "self": Asset( + name="example_asset_func", uri="s3://bucket/object", group="MLModel", extra={"k": "v"} + ), + "context": {"k": "v"}, + "inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"), + "inlet_asset_2": Asset(name="inlet_asset_2"), + } diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index f461afa31c89..bb942a4a01d4 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -134,7 +134,7 @@ def test_serialization(asset_timetable: AssetOrTimeSchedule, monkeypatch: Any) - "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": [{"__type": "asset", "uri": "test_asset", "extra": {}}], + "objects": [{"__type": "asset", "uri": 
"test_asset", "name": "test_asset", "extra": {}}], }, } @@ -152,7 +152,7 @@ def test_deserialization(monkeypatch: Any) -> None: "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": [{"__type": "asset", "uri": "test_asset", "extra": None}], + "objects": [{"__type": "asset", "name": "test_asset", "uri": "test_asset", "extra": None}], }, } deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data)