Skip to content

Commit

Permalink
refactor(asset): rename _AssetAliasCondition as AssetAliasCondition
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Nov 15, 2024
1 parent 0521112 commit 757f6a6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
3 changes: 1 addition & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
AssetAny,
AssetRef,
BaseAsset,
_AssetAliasCondition,
)
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
Expand Down Expand Up @@ -1054,7 +1053,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
)
)
elif isinstance(obj, AssetAlias):
cond = _AssetAliasCondition(obj.name)
cond = AssetAliasCondition(obj.name)

deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
return deps
Expand Down
4 changes: 2 additions & 2 deletions airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING, Any, Collection, Sequence

from airflow.sdk.definitions.asset import AssetAlias, _AssetAliasCondition
from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.utils import timezone

Expand Down Expand Up @@ -169,7 +169,7 @@ def __init__(self, assets: BaseAsset) -> None:
super().__init__()
self.asset_condition = assets
if isinstance(self.asset_condition, AssetAlias):
self.asset_condition = _AssetAliasCondition(self.asset_condition.name)
self.asset_condition = AssetAliasCondition(self.asset_condition.name)

if not next(self.asset_condition.iter_assets(), False):
self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def __init__(self, *objects: BaseAsset) -> None:
raise TypeError("expect asset expressions in condition")

self.objects = [
_AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
]

def evaluate(self, statuses: dict[str, bool]) -> bool:
Expand Down Expand Up @@ -458,7 +458,7 @@ def as_expression(self) -> dict[str, Any]:
return {"any": [o.as_expression() for o in self.objects]}


class _AssetAliasCondition(AssetAny):
class AssetAliasCondition(AssetAny):
"""
Use to expand AssetAlias as AssetAny of its resolved Assets.
Expand All @@ -470,7 +470,7 @@ def __init__(self, name: str) -> None:
self.objects = expand_alias_to_assets(name)

def __repr__(self) -> str:
return f"_AssetAliasCondition({', '.join(map(str, self.objects))})"
return f"AssetAliasCondition({', '.join(map(str, self.objects))})"

def as_expression(self) -> Any:
"""
Expand Down
13 changes: 6 additions & 7 deletions task_sdk/tests/defintions/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
BaseAsset,
Dataset,
Model,
_AssetAliasCondition,
_get_normalized_scheme,
_sanitize_uri,
)
Expand Down Expand Up @@ -565,7 +564,7 @@ def test_normalize_uri_valid_uri():
@pytest.mark.skip_if_database_isolation_mode
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_assets")
class Test_AssetAliasCondition:
class TestAssetAliasCondition:
@pytest.fixture
def asset_1(self, session):
"""Example asset links to asset alias resolved_asset_alias_2."""
Expand Down Expand Up @@ -601,22 +600,22 @@ def resolved_asset_alias_2(self, session, asset_1):
return asset_alias_2

def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2):
cond = _AssetAliasCondition(name=asset_alias_1.name)
cond = AssetAliasCondition(name=asset_alias_1.name)
assert cond.objects == []

cond = _AssetAliasCondition(name=resolved_asset_alias_2.name)
cond = AssetAliasCondition(name=resolved_asset_alias_2.name)
assert cond.objects == [Asset(uri=asset_1.uri)]

def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
for assset_alias in (asset_alias_1, resolved_asset_alias_2):
cond = _AssetAliasCondition(assset_alias.name)
cond = AssetAliasCondition(assset_alias.name)
assert cond.as_expression() == {"alias": assset_alias.name}

def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1):
cond = _AssetAliasCondition(asset_alias_1.name)
cond = AssetAliasCondition(asset_alias_1.name)
assert cond.evaluate({asset_1.uri: True}) is False

cond = _AssetAliasCondition(resolved_asset_alias_2.name)
cond = AssetAliasCondition(resolved_asset_alias_2.name)
assert cond.evaluate({asset_1.uri: True}) is True


Expand Down

0 comments on commit 757f6a6

Please sign in to comment.