From a57ed8b96096c80dfa7bb9bea5c01b711fdd2474 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 18:36:12 +0800 Subject: [PATCH] fix(assets): extend Asset as_expression methods to include name, group fields (also AssetAlias group field) --- airflow/assets/__init__.py | 10 ++++++---- airflow/serialization/serialized_objects.py | 2 +- airflow/timetables/simple.py | 4 +++- tests/api_fastapi/core_api/routes/ui/test_assets.py | 4 +++- tests/assets/test_asset.py | 10 +++++----- tests/www/views/test_views_grid.py | 4 +++- 6 files changed, 21 insertions(+), 13 deletions(-) diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index 59e0b8668449..12fd1af56566 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -333,7 +333,7 @@ def as_expression(self) -> Any: :meta private: """ - return self.uri + return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}} def iter_assets(self) -> Iterator[tuple[str, Asset]]: yield self.uri, self @@ -380,7 +380,8 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = [ - _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects + _AssetAliasCondition(name=obj.name, group=obj.group) if isinstance(obj, AssetAlias) else obj + for obj in objects ] def evaluate(self, statuses: dict[str, bool]) -> bool: @@ -440,8 +441,9 @@ class _AssetAliasCondition(AssetAny): :meta private: """ - def __init__(self, name: str) -> None: + def __init__(self, name: str, group: str) -> None: self.name = name + self.group = group self.objects = expand_alias_to_assets(name) def __repr__(self) -> str: @@ -453,7 +455,7 @@ def as_expression(self) -> Any: :meta private: """ - return {"alias": self.name} + return {"alias": {"name": self.name, "group": self.group}} def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: yield self.name, AssetAlias(self.name) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index c41f71126e2c..6e007e2529e4 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1048,7 +1048,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) elif isinstance(obj, AssetAlias): - cond = _AssetAliasCondition(obj.name) + cond = _AssetAliasCondition(name=obj.name, group=obj.group) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index 5a931b40dd11..2dbefc7a4118 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -169,7 +169,9 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = _AssetAliasCondition(self.asset_condition.name) + self.asset_condition = _AssetAliasCondition( + name=self.asset_condition.name, group=self.asset_condition.group + ) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index b71d80ae9d31..ee19d841a876 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -46,6 +46,8 @@ def test_next_run_assets(test_client, dag_maker): assert response.status_code == 200 assert response.json() == { - "asset_expression": {"all": ["s3://bucket/key/1"]}, + "asset_expression": { + "all": [{"asset": {"uri": "s3://bucket/key/1", "name": "s3://bucket/key/1", "group": ""}}] + }, "events": [{"id": 17, "uri": "s3://bucket/key/1", "lastUpdate": None}], } diff --git a/tests/assets/test_asset.py b/tests/assets/test_asset.py index a454fd2826bd..e18ceed89c01 100644 --- a/tests/assets/test_asset.py +++ b/tests/assets/test_asset.py @@ -597,22 +597,22 @@ def resolved_asset_alias_2(self, session, asset_1): return asset_alias_2 def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2): - cond = _AssetAliasCondition(name=asset_alias_1.name) + cond = _AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) assert cond.objects == [] - cond = _AssetAliasCondition(name=resolved_asset_alias_2.name) + cond = _AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) assert cond.objects == [Asset(uri=asset_1.uri)] def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): for assset_alias in (asset_alias_1, resolved_asset_alias_2): - cond = _AssetAliasCondition(assset_alias.name) + cond = _AssetAliasCondition(name=assset_alias.name, group=assset_alias.group) assert cond.as_expression() == {"alias": assset_alias.name} def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1): - cond = _AssetAliasCondition(asset_alias_1.name) + cond = _AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) assert cond.evaluate({asset_1.uri: True}) is False - cond = _AssetAliasCondition(resolved_asset_alias_2.name) + cond = _AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) assert cond.evaluate({asset_1.uri: True}) is True diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index f7c8fdcf2b0d..235f9b68096c 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -508,7 +508,9 @@ def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch): assert resp.status_code == 200, resp.json assert resp.json == { - "asset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, + "asset_expression": { + "all": [{"asset": {"uri": "s3://bucket/key/1", "name": "s3://bucket/key/2", "group": ""}}] + }, "events": [ {"id": asset1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, {"id": asset2_id, "uri": "s3://bucket/key/2", "lastUpdate": None},