Skip to content

Commit

Permalink
test(models/test_dag): extend asset test cases to cover name, uri, group
Browse files (browse the repository at this point in the history)
  • Loading branch information
Lee-W committed Nov 14, 2024
1 parent 54105cf commit cd30368
Showing 1 changed file with 70 additions and 30 deletions.
100 changes: 70 additions & 30 deletions tests/models/test_dag.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -843,23 +843,32 @@ def test_bulk_write_to_db_assets(self):
"""
dag_id1 = "test_asset_dag1"
dag_id2 = "test_asset_dag2"

task_id = "test_asset_task"

uri1 = "s3://asset/1"
d1 = Asset(uri1, extra={"not": "used"})
d2 = Asset("s3://asset/2")
d3 = Asset("s3://asset/3")
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[d1])
EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
a1 = Asset(uri=uri1, name="test_asset_1", extra={"not": "used"}, group="test-group")
a2 = Asset(uri="s3://asset/2", name="test_asset_2", group="test-group")
a3 = Asset(uri="s3://asset/3", name="test_asset-3", group="test-group")

dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[a1])
EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2, a3])

dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None)
EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})])
EmptyOperator(
task_id=task_id,
dag=dag2,
outlets=[Asset(uri=uri1, name="test_asset_1", extra={"should": "be used"}, group="test-group")],
)

session = settings.Session()
dag1.clear()
DAG.bulk_write_to_db([dag1, dag2], session=session)
session.commit()
stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
asset1_orm = stored_assets[d1.uri]
asset2_orm = stored_assets[d2.uri]
asset3_orm = stored_assets[d3.uri]
asset1_orm = stored_assets[a1.uri]
asset2_orm = stored_assets[a2.uri]
asset3_orm = stored_assets[a3.uri]
assert stored_assets[uri1].extra == {"should": "be used"}
assert [x.dag_id for x in asset1_orm.consuming_dags] == [dag_id1]
assert [(x.task_id, x.dag_id) for x in asset1_orm.producing_tasks] == [(task_id, dag_id2)]
Expand All @@ -882,15 +891,15 @@ def test_bulk_write_to_db_assets(self):
# so if any references are *removed*, they should also be deleted from the DB
# so let's remove some references and see what happens
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=None)
EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2])
EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2])
dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None)
EmptyOperator(task_id=task_id, dag=dag2)
DAG.bulk_write_to_db([dag1, dag2], session=session)
session.commit()
session.expunge_all()
stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
asset1_orm = stored_assets[d1.uri]
asset2_orm = stored_assets[d2.uri]
asset1_orm = stored_assets[a1.uri]
asset2_orm = stored_assets[a2.uri]
assert [x.dag_id for x in asset1_orm.consuming_dags] == []
assert set(
session.query(
Expand Down Expand Up @@ -920,10 +929,10 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session):
"""
# Create four assets - two that have references and two that are unreferenced and marked as
# orphans
asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset3 = Asset(uri="ds3")
asset4 = Asset(uri="ds4")
asset1 = Asset(uri="test://asset1", name="asset1", group="test-group")
asset2 = Asset(uri="test://asset2", name="asset2", group="test-group")
asset3 = Asset(uri="test://asset3", name="asset3", group="test-group")
asset4 = Asset(uri="test://asset4", name="asset4", group="test-group")

dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1])
BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3])
Expand Down Expand Up @@ -1393,8 +1402,11 @@ def test_timetable_and_description_from_schedule_arg(
assert dag.timetable.description == interval_description

def test_timetable_and_description_from_asset(self):
dag = DAG("test_schedule_interval_arg", schedule=[Asset(uri="hello")], start_date=TEST_DATE)
assert dag.timetable == AssetTriggeredTimetable(Asset(uri="hello"))
uri = "test://asset"
dag = DAG(
"test_schedule_interval_arg", schedule=[Asset(uri=uri, group="test-group")], start_date=TEST_DATE
)
assert dag.timetable == AssetTriggeredTimetable(Asset(uri=uri, group="test-group"))
assert dag.timetable.description == "Triggered by assets"

@pytest.mark.parametrize(
Expand Down Expand Up @@ -2159,7 +2171,7 @@ def test_dags_needing_dagruns_not_too_early(self):
session.close()

def test_dags_needing_dagruns_assets(self, dag_maker, session):
asset = Asset(uri="hello")
asset = Asset(uri="test://asset", group="test-group")
with dag_maker(
session=session,
dag_id="my_dag",
Expand Down Expand Up @@ -2391,8 +2403,8 @@ def test__processor_dags_folder(self, session):

@pytest.mark.need_serialized_dag
def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker):
asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset1 = Asset(uri="test://asset1", group="test-group")
asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group")

for dag_id, asset in [("assets-1", asset1), ("assets-2", asset2)]:
with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session):
Expand Down Expand Up @@ -2441,10 +2453,15 @@ def test_asset_expression(self, session: Session) -> None:
dag = DAG(
dag_id="test_dag_asset_expression",
schedule=AssetAny(
Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}),
Asset(uri="s3://dag1/output_1.txt", extra={"hi": "bye"}, group="test-group"),
AssetAll(
Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}),
Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}),
Asset(
uri="s3://dag2/output_1.txt",
name="test_asset_2",
extra={"hi": "bye"},
group="test-group",
),
Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}, group="test-group"),
),
AssetAlias(name="test_name"),
),
Expand All @@ -2455,9 +2472,32 @@ def test_asset_expression(self, session: Session) -> None:
expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one()
assert expression == {
"any": [
"s3://dag1/output_1.txt",
{"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
{"alias": "test_name"},
{
"asset": {
"uri": "s3://dag1/output_1.txt",
"name": "s3://dag1/output_1.txt",
"group": "test-group",
}
},
{
"all": [
{
"asset": {
"uri": "s3://dag2/output_1.txt",
"name": "test_asset_2",
"group": "test-group",
}
},
{
"asset": {
"uri": "s3://dag3/output_3.txt",
"name": "s3://dag3/output_3.txt",
"group": "test-group",
}
},
]
},
{"alias": {"name": "test_name", "group": ""}},
]
}

Expand Down Expand Up @@ -3022,9 +3062,9 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict):

@pytest.mark.need_serialized_dag
def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset3 = Asset(uri="ds3")
asset1 = Asset(uri="test://asset1", name="test_asset1", group="test-group")
asset2 = Asset(uri="test://asset2", group="test-group")
asset3 = Asset(uri="test://asset3", group="test-group")
with dag_maker(dag_id="assets-1", schedule=[asset2]):
pass
dag1 = dag_maker.dag
Expand Down

0 comments on commit cd30368

Please sign in to comment.