Skip to content

Commit

Permalink
Make Run status information more clear (#29)
Browse files Browse the repository at this point in the history
* rename test folder to avoid conflict with airflow module name

* fix: make the failure message more clear
  • Loading branch information
zongsizhang authored Feb 20, 2025
1 parent 2f0c689 commit 33128af
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 7 deletions.
11 changes: 7 additions & 4 deletions airflow_providers_wherobots/operators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,13 @@ def execute(self, context) -> Any:
run = self._wait_run_simple(rest_api_hook, run)
# loop end, means run is in terminal state
self._log_run_status(run)
if run.status in (RunStatus.FAILED, RunStatus.CANCELLED):
raise RuntimeError(
f"Run {run.ext_id} failed or cancelled by another party"
)
if run.status == RunStatus.FAILED:
# check events, see if the run is timeout
if run.is_timeout:
raise RuntimeError(f"Run {run.ext_id} failed due to timeout")
raise RuntimeError(f"Run {run.ext_id} failed, please check the logs")
if run.status == RunStatus.CANCELLED:
raise RuntimeError(f"Run {run.ext_id} was cancelled by user")

def on_kill(self) -> None:
"""
Expand Down
23 changes: 23 additions & 0 deletions airflow_providers_wherobots/wherobots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,34 @@ class WherobotsModel(BaseModel):
model_config = ConfigDict(from_attributes=True, populate_by_name=True)


class KubeAppEvent(BaseModel):
code: str
message: Optional[str] = None


class KubeApp(BaseModel):
events: List[KubeAppEvent]


class Run(WherobotsModel):
name: str
status: RunStatus
start_time: Optional[datetime] = Field(default=None, alias="startTime")
end_time: Optional[datetime] = Field(default=None, alias="completeTime")
kube_app: Optional[KubeApp] = Field(default=None, alias="kubeApp")

@property
def is_timeout(self) -> bool:
if not self.kube_app or not self.kube_app.events:
return False
return any(
(
event.code == "RUN_FAIL_EXEC"
and event.message
and "timeout" in event.message.lower()
)
for event in self.kube_app.events
)


class LogItem(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions tests/integration_tests/operator/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,24 @@ def test_prod_run_success(prod_conn: Connection, dag: DAG) -> None:
ti = build_ti(dag, task_id=operator.task_id)
ti.run(ignore_ti_state=True)
assert ti.state == TaskInstanceState.SUCCESS


@pytest.mark.usefixtures("clean_airflow_db")
def test_prod_run_timeout(prod_conn: Connection, dag: DAG) -> None:
operator = WherobotsRunOperator(
region=Region.AWS_US_WEST_2,
wherobots_conn_id=prod_conn.conn_id,
task_id="test_run_smoke",
name="airflow_operator_test_run_{{ ts_nodash }}",
run_python={
"uri": "s3://wbts-wbc-m97rcg45xi/42ly7mi0p1/data/shared/tile-generation-example.py"
},
dag=dag,
do_xcom_push=True,
timeout_seconds=2,
)
ti = build_ti(dag, task_id=operator.task_id)
with pytest.raises(RuntimeError) as e:
ti.run(ignore_ti_state=True)
assert "failed due to timeout" in str(e.value)
assert ti.state == TaskInstanceState.FAILED
Empty file.
209 changes: 209 additions & 0 deletions tests/unit_tests/operators/test_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
Test the operators in run module
"""

import datetime
import itertools
import uuid
from typing import Tuple, List
from unittest.mock import MagicMock

import pendulum
import pytest
from airflow import DAG
from airflow.models import DagRun, TaskInstance
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType
from pytest_mock import MockerFixture
from wherobots.db import Region

from airflow_providers_wherobots.operators.run import WherobotsRunOperator
from airflow_providers_wherobots.wherobots.models import (
RunStatus,
LogsResponse,
Run,
LogItem,
)
from tests.unit_tests import helpers
from tests.unit_tests.helpers import run_factory

DEFAULT_START = pendulum.datetime(2021, 9, 13, tz="UTC")
DEFAULT_END = DEFAULT_START + datetime.timedelta(days=1)

TEST_DAG_ID = "test_run_operator"
TEST_TASK_ID = "run_operator"


def build_ti(dag: DAG, task_id: str, start=DEFAULT_START, end=DEFAULT_END):
dag_run: DagRun = dag.create_dagrun(
state=DagRunState.RUNNING,
execution_date=start,
data_interval=(start, end),
start_date=start,
run_type=DagRunType.MANUAL,
)
ti: TaskInstance = dag_run.get_task_instance(task_id=task_id)
ti.task = dag.get_task(task_id=task_id)
return ti


def execute_dag(dag: DAG, task_id: str, start=DEFAULT_START, end=DEFAULT_END):
ti = build_ti(dag, task_id, start=start, end=end)
ti.run(ignore_ti_state=True)


class TestWherobotsRunOperator:
@pytest.mark.usefixtures("clean_airflow_db")
def test_render_template(self, mocker: MockerFixture, dag: DAG):
data_interval_start = pendulum.datetime(2021, 9, 13, tz="UTC")
create_run: MagicMock = mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.create_run",
return_value=run_factory.build(status=RunStatus.COMPLETED),
)
operator = WherobotsRunOperator(
region=Region.AWS_US_WEST_2,
task_id="test_render_template_python",
name="test_run_{{ ds }}",
run_python={
"uri": "s3://bucket/test-{{ ds }}.py",
"args": ["{{ ds }}"],
},
dag=dag,
)
assert operator.region == Region.AWS_US_WEST_2
execute_dag(dag, task_id=operator.task_id)
assert create_run.call_count == 1
rendered_payload = create_run.call_args.args[0]
assert isinstance(rendered_payload, dict)
expected_ds = data_interval_start.format("YYYY-MM-DD")
assert rendered_payload["name"] == f"test_run_{expected_ds}"
assert (
rendered_payload["runPython"]["uri"] == f"s3://bucket/test-{expected_ds}.py"
)
assert rendered_payload["runPython"]["args"] == [expected_ds]

@pytest.mark.usefixtures("clean_airflow_db")
def test_default_name(self, mocker: MockerFixture, dag: DAG):
data_interval_start = pendulum.datetime(2021, 9, 13, tz="UTC")
create_run: MagicMock = mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.create_run",
return_value=run_factory.build(status=RunStatus.COMPLETED),
)
operator = WherobotsRunOperator(
task_id="test_default_name",
run_python={"uri": ""},
dag=dag,
)
execute_dag(dag, task_id=operator.task_id)
rendered_payload = create_run.call_args.args[0]
assert isinstance(rendered_payload, dict)
assert rendered_payload["name"] == operator.default_run_name.replace(
"{{ ts_nodash }}", data_interval_start.strftime("%Y%m%dT%H%M%S")
)

@pytest.mark.usefixtures("clean_airflow_db")
@pytest.mark.parametrize(
"poll_logs,test_item",
itertools.product(
[False, True],
[
(
[
run_factory.build(status=RunStatus.RUNNING),
run_factory.build(status=RunStatus.FAILED),
],
TaskInstanceState.FAILED,
),
(
[run_factory.build(status=RunStatus.CANCELLED)],
TaskInstanceState.FAILED,
),
(
[
run_factory.build(status=RunStatus.RUNNING),
run_factory.build(status=RunStatus.COMPLETED),
],
TaskInstanceState.SUCCESS,
),
],
),
)
def test_execute_handle_states(
self,
mocker: MockerFixture,
dag: DAG,
poll_logs: bool,
test_item: Tuple[List[Run], TaskInstanceState],
):
get_run_results, task_state = test_item
mocked_create_run = mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.create_run",
return_value=run_factory.build(status=RunStatus.PENDING),
)
mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.get_run",
side_effect=get_run_results,
)
if poll_logs:
mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.get_run_logs",
return_value=LogsResponse(items=[], current_page=0, next_page=None),
)
operator = WherobotsRunOperator(
task_id=f"test_execute_{uuid.uuid4()}",
run_python={"uri": ""},
dag=dag,
polling_interval=0,
poll_logs=poll_logs,
do_xcom_push=True,
)
ti = build_ti(dag, task_id=operator.task_id)
try:
ti.run(ignore_ti_state=True)
except Exception as e:
assert isinstance(e, RuntimeError)
assert ti.state == task_state
# test xcom push
if task_state == TaskInstanceState.SUCCESS:
ti.xcom_push("key", "value")
assert ti.xcom_pull(key="run_id") == mocked_create_run.return_value.ext_id

def test_on_kill(
self,
dag: DAG,
mocker: MockerFixture,
):
mocked_cancel_run: MagicMock = mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.cancel_run"
)
operator = WherobotsRunOperator(
task_id="test_render_template_python",
name="test_run_{{ ds }}",
run_python={
"uri": "s3://bucket/test-{{ ds }}.py",
"args": ["{{ ds }}"],
"entrypoint": "src.main_{{ ds }}",
},
dag=dag,
)
operator.on_kill()
assert mocked_cancel_run.call_count == 0
operator.run_id = "test_run_id"
operator.on_kill()
mocked_cancel_run.assert_called_with(operator.run_id)

def test_poll_and_display_logs(self, mocker: MockerFixture):
hook = mocker.MagicMock()
test_run: Run = helpers.run_factory.build()
hook.get_run_logs.return_value = LogsResponse(
items=[LogItem(raw="log1", timestamp=1), LogItem(raw="log2", timestamp=2)],
current_page=1,
next_page=2,
)
operator = WherobotsRunOperator(
task_id="test_poll_and_display_logs",
run_python={"uri": ""},
dag=DAG("test_poll_and_display_logs"),
)
assert operator.poll_and_display_logs(hook, test_run, 0) == 2
hook.get_run_logs.assert_called_with(test_run.ext_id, 0)
33 changes: 33 additions & 0 deletions tests/unit_tests/operators/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Test operators
"""

from unittest import mock
from unittest.mock import MagicMock

from wherobots.db import Runtime

from airflow_providers_wherobots.operators.sql import WherobotsSqlOperator


def mock_wherobots_db_connection():
mock_connection = MagicMock()
mock_cursor = MagicMock(rowcount=1)
mock_connection.cursor.return_value = mock_cursor
mock_cursor.fetchall.return_value = [("1",)]
return mock_connection


class TestWherobotsSqlOperator:
@mock.patch(
"airflow_providers_wherobots.hooks.sql.connect",
return_value=mock_wherobots_db_connection(),
)
def test_default_handler(self, mock_connect):
# Instantiate hook
operator = WherobotsSqlOperator(
task_id="test_task",
sql="select * from table_a",
runtime=Runtime.LARGE,
)
operator.execute(context={})

0 comments on commit 33128af

Please sign in to comment.