
Commit
fix(scheduler_job_runner/asset): fix how asset dag warning is added (apache#43873)

The correct logic is:

1. Find the warnings that should exist after this round
2. Delete the warnings that are no longer needed
3. Update the warnings if they already exist and add new warnings if they do not yet exist
Lee-W authored Nov 14, 2024
1 parent c84d356 commit c94715f
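
The three-step logic above is a desired-state sync: compute the complete set of warnings that should exist, prune everything outside that set, then upsert the rest. A minimal sketch of the pattern, with a plain dict standing in for the dag_warning table (the function and sample data below are illustrative, not Airflow's API):

from itertools import groupby
from operator import itemgetter


def sync_warnings(existing: dict[str, str], conflicts: list[tuple[str, str]]) -> dict[str, str]:
    """Hypothetical stand-in: `existing` maps dag_id -> message (the DB state);
    `conflicts` holds the (dag_id, message) pairs produced in this round."""
    # 1. Find the warnings that should exist after this round, merging all
    #    messages for a dag into one warning. groupby() only merges
    #    *consecutive* keys, hence the sort first.
    desired = {
        dag_id: "\n".join(msg for _, msg in group)
        for dag_id, group in groupby(sorted(conflicts), key=itemgetter(0))
    }
    # 2. Delete the warnings that are no longer needed.
    for dag_id in list(existing):
        if dag_id not in desired:
            del existing[dag_id]
    # 3. Update warnings that already exist; add the ones that do not.
    existing.update(desired)
    return existing


state = {"dag_a": "stale", "dag_b": "stale"}
print(sync_warnings(state, [("dag_a", "msg 1"), ("dag_a", "msg 2"), ("dag_c", "msg 3")]))
# {'dag_a': 'msg 1\nmsg 2', 'dag_c': 'msg 3'}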
Showing 4 changed files with 144 additions and 23 deletions.
34 changes: 23 additions & 11 deletions airflow/jobs/scheduler_job_runner.py
@@ -27,6 +27,7 @@
 from collections import Counter, defaultdict, deque
 from datetime import timedelta
 from functools import lru_cache, partial
+from itertools import groupby
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator

@@ -2107,44 +2108,55 @@ def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Sess
         active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
         active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}

-        def _generate_dag_warnings(offending: AssetModel, attr: str, value: str) -> Iterator[DagWarning]:
+        def _generate_warning_message(
+            offending: AssetModel, attr: str, value: str
+        ) -> Iterator[tuple[str, str]]:
             for ref in itertools.chain(offending.consuming_dags, offending.producing_tasks):
-                yield DagWarning(
-                    dag_id=ref.dag_id,
-                    warning_type=DagWarningType.ASSET_CONFLICT,
-                    message=f"Cannot activate asset {offending}; {attr} is already associated to {value!r}",
+                yield (
+                    ref.dag_id,
+                    f"Cannot activate asset {offending}; {attr} is already associated to {value!r}",
                 )

-        def _activate_assets_generate_warnings() -> Iterator[DagWarning]:
+        def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
             incoming_name_to_uri: dict[str, str] = {}
             incoming_uri_to_name: dict[str, str] = {}
             for asset in assets:
                 if (asset.name, asset.uri) in active_assets:
                     continue
                 existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
                 if existing_uri is not None and existing_uri != asset.uri:
-                    yield from _generate_dag_warnings(asset, "name", existing_uri)
+                    yield from _generate_warning_message(asset, "name", existing_uri)
                     continue
                 existing_name = active_uri_to_name.get(asset.uri) or incoming_uri_to_name.get(asset.uri)
                 if existing_name is not None and existing_name != asset.name:
-                    yield from _generate_dag_warnings(asset, "uri", existing_name)
+                    yield from _generate_warning_message(asset, "uri", existing_name)
                     continue
                 incoming_name_to_uri[asset.name] = asset.uri
                 incoming_uri_to_name[asset.uri] = asset.name
                 session.add(AssetActive.for_asset(asset))

-        warnings_to_have = {w.dag_id: w for w in _activate_assets_generate_warnings()}
+        warnings_to_have = {
+            dag_id: DagWarning(
+                dag_id=dag_id,
+                warning_type=DagWarningType.ASSET_CONFLICT,
+                message="\n".join([message for _, message in group]),
+            )
+            for dag_id, group in groupby(
+                sorted(_activate_assets_generate_warnings()), key=operator.itemgetter(0)
+            )
+        }

         session.execute(
             delete(DagWarning).where(
                 DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
-                DagWarning.dag_id.in_(warnings_to_have),
+                DagWarning.dag_id.not_in(warnings_to_have),
             )
         )
         existing_warned_dag_ids: set[str] = set(
             session.scalars(
                 select(DagWarning.dag_id).where(
                     DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
-                    DagWarning.dag_id.not_in(warnings_to_have),
+                    DagWarning.dag_id.in_(warnings_to_have),
                 )
             )
         )
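
Two details in the hunk above carry the fix. First, the in_/not_in swap is the bug being corrected: the DELETE now removes ASSET_CONFLICT rows whose dag_id is not among the warnings to keep, while the SELECT fetches the ones that are, so the code that follows (elided here) can update them in place instead of inserting duplicates. Second, itertools.groupby only groups consecutive keys, which is why the (dag_id, message) pairs are sorted before grouping. A quick demonstration with hypothetical sample data:

from itertools import groupby
from operator import itemgetter

pairs = [("dag_b", "m1"), ("dag_a", "m2"), ("dag_b", "m3")]

# Unsorted input: "dag_b" is split into two separate groups.
print([(key, len(list(group))) for key, group in groupby(pairs, key=itemgetter(0))])
# [('dag_b', 1), ('dag_a', 1), ('dag_b', 1)]

# Sorted input: one group per dag_id, which the scheduler code relies on.
print([(key, len(list(group))) for key, group in groupby(sorted(pairs), key=itemgetter(0))])
# [('dag_a', 1), ('dag_b', 2)]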
2 changes: 1 addition & 1 deletion airflow/models/dagwarning.py
@@ -38,7 +38,7 @@ class DagWarning(Base):
     A table to store DAG warnings.

     DAG warnings are problems that don't rise to the level of failing the DAG parse
-    but which users should nonetheless be warned about. These warnings are recorded
+    but which users should nonetheless be warned about. These warnings are recorded
     when parsing DAG and displayed on the Webserver in a flash message.
     """

109 changes: 109 additions & 0 deletions tests/jobs/test_scheduler_job.py
@@ -57,6 +57,7 @@
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
+from airflow.models.dagwarning import DagWarning
 from airflow.models.db_callback_request import DbCallbackRequest
 from airflow.models.pool import Pool
 from airflow.models.serialized_dag import SerializedDagModel
@@ -6285,6 +6286,114 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, cap
         # Check if the second dagrun was created
         assert DagRun.find(dag_id="testdag2", session=session)

+    def test_activate_referenced_assets_with_no_existing_warning(self, session):
+        dag_warnings = session.query(DagWarning).all()
+        assert dag_warnings == []
+
+        dag_id1 = "test_asset_dag1"
+        asset1_name = "asset1"
+        asset_extra = {"foo": "bar"}
+
+        asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra)
+        asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra)
+        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1])
+
+        DAG.bulk_write_to_db([dag1], session=session)
+
+        asset_models = session.scalars(select(AssetModel)).all()
+
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
+        dag_warning = session.scalar(
+            select(DagWarning).where(
+                DagWarning.dag_id == dag_id1, DagWarning.warning_type == "asset conflict"
+            )
+        )
+        assert dag_warning.message == (
+            "Cannot activate asset AssetModel(name='asset1', uri=\"it's duplicate\", extra={'foo': 'bar'}); "
+            "name is already associated to 's3://bucket/key/1'"
+        )
+
+    def test_activate_referenced_assets_with_existing_warnings(self, session):
+        dag_ids = [f"test_asset_dag{i}" for i in range(1, 4)]
+        asset1_name = "asset1"
+        asset_extra = {"foo": "bar"}
+
+        session.add_all(
+            [
+                DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist")
+                for dag_id in dag_ids
+            ]
+        )
+
+        asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra)
+        asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra)
+        asset1_2 = Asset(name=asset1_name, uri="it's duplicate 2", extra=asset_extra)
+        dag1 = DAG(dag_id=dag_ids[0], start_date=DEFAULT_DATE, schedule=[asset1, asset1_1])
+        dag2 = DAG(dag_id=dag_ids[1], start_date=DEFAULT_DATE)
+        dag3 = DAG(dag_id=dag_ids[2], start_date=DEFAULT_DATE, schedule=[asset1_2])
+
+        DAG.bulk_write_to_db([dag1, dag2, dag3], session=session)
+
+        asset_models = session.scalars(select(AssetModel)).all()
+
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
+        dag_warning = session.scalar(
+            select(DagWarning).where(
+                DagWarning.dag_id == dag_ids[0], DagWarning.warning_type == "asset conflict"
+            )
+        )
+        assert dag_warning.message == (
+            "Cannot activate asset AssetModel(name='asset1', uri=\"it's duplicate\", extra={'foo': 'bar'}); "
+            "name is already associated to 's3://bucket/key/1'"
+        )
+
+        dag_warning = session.scalar(
+            select(DagWarning).where(
+                DagWarning.dag_id == dag_ids[1], DagWarning.warning_type == "asset conflict"
+            )
+        )
+        assert dag_warning is None
+
+        dag_warning = session.scalar(
+            select(DagWarning).where(
+                DagWarning.dag_id == dag_ids[2], DagWarning.warning_type == "asset conflict"
+            )
+        )
+        assert dag_warning.message == (
+            "Cannot activate asset AssetModel(name='asset1', uri=\"it's duplicate 2\", extra={'foo': 'bar'}); "
+            "name is already associated to 's3://bucket/key/1'"
+        )
+
+    def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self, session):
+        dag_id = "test_asset_dag"
+        asset1_name = "asset1"
+        asset_extra = {"foo": "bar"}
+
+        session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist"))
+
+        schedule = [Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra)]
+        schedule.extend(
+            [Asset(name=asset1_name, uri=f"it's duplicate {i}", extra=asset_extra) for i in range(100)]
+        )
+        dag1 = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule=schedule)
+
+        DAG.bulk_write_to_db([dag1], session=session)
+
+        asset_models = session.scalars(select(AssetModel)).all()
+
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
+        dag_warning = session.scalar(
+            select(DagWarning).where(DagWarning.dag_id == dag_id, DagWarning.warning_type == "asset conflict")
+        )
+        for i in range(100):
+            assert f"it's duplicate {i}" in dag_warning.message
+

 @pytest.mark.need_serialized_dag
 def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
22 changes: 11 additions & 11 deletions tests/models/test_dag.py
@@ -845,21 +845,21 @@ def test_bulk_write_to_db_assets(self):
         dag_id2 = "test_asset_dag2"
         task_id = "test_asset_task"
         uri1 = "s3://asset/1"
-        d1 = Asset(uri1, extra={"not": "used"})
-        d2 = Asset("s3://asset/2")
-        d3 = Asset("s3://asset/3")
-        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[d1])
-        EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
+        a1 = Asset(uri1, extra={"not": "used"})
+        a2 = Asset("s3://asset/2")
+        a3 = Asset("s3://asset/3")
+        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[a1])
+        EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2, a3])
         dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None)
         EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})])
         session = settings.Session()
         dag1.clear()
         DAG.bulk_write_to_db([dag1, dag2], session=session)
         session.commit()
         stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
-        asset1_orm = stored_assets[d1.uri]
-        asset2_orm = stored_assets[d2.uri]
-        asset3_orm = stored_assets[d3.uri]
+        asset1_orm = stored_assets[a1.uri]
+        asset2_orm = stored_assets[a2.uri]
+        asset3_orm = stored_assets[a3.uri]
         assert stored_assets[uri1].extra == {"should": "be used"}
         assert [x.dag_id for x in asset1_orm.consuming_dags] == [dag_id1]
         assert [(x.task_id, x.dag_id) for x in asset1_orm.producing_tasks] == [(task_id, dag_id2)]
@@ -882,15 +882,15 @@ def test_bulk_write_to_db_assets(self):
         # so if any references are *removed*, they should also be deleted from the DB
         # so let's remove some references and see what happens
         dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=None)
-        EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2])
+        EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2])
         dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None)
         EmptyOperator(task_id=task_id, dag=dag2)
         DAG.bulk_write_to_db([dag1, dag2], session=session)
         session.commit()
         session.expunge_all()
         stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
-        asset1_orm = stored_assets[d1.uri]
-        asset2_orm = stored_assets[d2.uri]
+        asset1_orm = stored_assets[a1.uri]
+        asset2_orm = stored_assets[a2.uri]
         assert [x.dag_id for x in asset1_orm.consuming_dags] == []
         assert set(
             session.query(
