diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index ce1f49e8adb7..15726d37a326 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -60,7 +60,6 @@ AssetAny, AssetRef, BaseAsset, - _AssetAliasCondition, ) from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding @@ -1054,7 +1053,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) elif isinstance(obj, AssetAlias): - cond = _AssetAliasCondition(obj.name) + cond = AssetAliasCondition(obj.name) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index adba135c5785..8ce498c9e049 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Collection, Sequence -from airflow.sdk.definitions.asset import AssetAlias, _AssetAliasCondition +from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.utils import timezone @@ -169,7 +169,7 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = _AssetAliasCondition(self.asset_condition.name) + self.asset_condition = AssetAliasCondition(self.asset_condition.name) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py index f093ea68c8f4..8e0224a41770 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/asset.py @@ -405,7 +405,7 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = [ - _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects + AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects ] def evaluate(self, statuses: dict[str, bool]) -> bool: @@ -458,7 +458,7 @@ def as_expression(self) -> dict[str, Any]: return {"any": [o.as_expression() for o in self.objects]} -class _AssetAliasCondition(AssetAny): +class AssetAliasCondition(AssetAny): """ Use to expand AssetAlias as AssetAny of its resolved Assets. @@ -470,7 +470,7 @@ def __init__(self, name: str) -> None: self.objects = expand_alias_to_assets(name) def __repr__(self) -> str: - return f"_AssetAliasCondition({', '.join(map(str, self.objects))})" + return f"AssetAliasCondition({', '.join(map(str, self.objects))})" def as_expression(self) -> Any: """ diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 11f067afa5da..2f3504f286de 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -37,7 +37,6 @@ BaseAsset, Dataset, Model, - _AssetAliasCondition, _get_normalized_scheme, _sanitize_uri, ) @@ -565,7 +564,7 @@ def test_normalize_uri_valid_uri(): @pytest.mark.skip_if_database_isolation_mode @pytest.mark.db_test @pytest.mark.usefixtures("clear_assets") -class Test_AssetAliasCondition: +class TestAssetAliasCondition: @pytest.fixture def asset_1(self, session): """Example asset links to asset alias resolved_asset_alias_2.""" @@ -601,22 +600,22 @@ def resolved_asset_alias_2(self, session, asset_1): return asset_alias_2 def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2): - cond = _AssetAliasCondition(name=asset_alias_1.name) + cond = AssetAliasCondition(name=asset_alias_1.name) assert cond.objects == [] - cond = _AssetAliasCondition(name=resolved_asset_alias_2.name) + cond = AssetAliasCondition(name=resolved_asset_alias_2.name) assert cond.objects == [Asset(uri=asset_1.uri)] def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): for assset_alias in (asset_alias_1, resolved_asset_alias_2): - cond = _AssetAliasCondition(assset_alias.name) + cond = AssetAliasCondition(assset_alias.name) assert cond.as_expression() == {"alias": assset_alias.name} def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1): - cond = _AssetAliasCondition(asset_alias_1.name) + cond = AssetAliasCondition(asset_alias_1.name) assert cond.evaluate({asset_1.uri: True}) is False - cond = _AssetAliasCondition(resolved_asset_alias_2.name) + cond = AssetAliasCondition(resolved_asset_alias_2.name) assert cond.evaluate({asset_1.uri: True}) is True