diff --git a/providers/src/airflow/providers/dbt/cloud/operators/dbt.py b/providers/src/airflow/providers/dbt/cloud/operators/dbt.py index c26e67e2a8b1..8795ebf0ca71 100644 --- a/providers/src/airflow/providers/dbt/cloud/operators/dbt.py +++ b/providers/src/airflow/providers/dbt/cloud/operators/dbt.py @@ -149,6 +149,8 @@ def execute(self, context: Context): self.run_id = non_terminal_runs[0]["id"] job_run_url = non_terminal_runs[0]["href"] + is_retry = context["ti"].try_number != 1 + if not self.reuse_existing_run or not non_terminal_runs: trigger_job_response = self.hook.trigger_job_run( account_id=self.account_id, @@ -156,7 +158,7 @@ def execute(self, context: Context): cause=self.trigger_reason, steps_override=self.steps_override, schema_override=self.schema_override, - retry_from_failure=self.retry_from_failure, + retry_from_failure=is_retry and self.retry_from_failure, additional_run_config=self.additional_run_config, ) self.run_id = trigger_job_response.json()["data"]["id"] diff --git a/providers/tests/dbt/cloud/operators/test_dbt.py b/providers/tests/dbt/cloud/operators/test_dbt.py index eb50bd5a22a2..a5f8752ffb3e 100644 --- a/providers/tests/dbt/cloud/operators/test_dbt.py +++ b/providers/tests/dbt/cloud/operators/test_dbt.py @@ -64,6 +64,17 @@ ), } } +JOB_RUN_ERROR_RESPONSE = { + "data": [ + { + "id": RUN_ID, + "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( + account_id=ACCOUNT_ID, project_id=PROJECT_ID, run_id=RUN_ID + ), + "status": DbtCloudJobRunStatus.ERROR.value, + } + ] +} def mock_response_json(response: dict): @@ -421,6 +432,73 @@ def test_execute_retry_from_failure(self, mock_run_job, conn_id, account_id): additional_run_config=self.config["additional_run_config"], ) + @patch.object(DbtCloudHook, "_run_and_get_response") + @pytest.mark.parametrize( + "conn_id, account_id", + [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + def test_execute_retry_from_failure_run(self, mock_run_req, conn_id, account_id): + operator = DbtCloudRunJobOperator( + task_id=TASK_ID, + dbt_cloud_conn_id=conn_id, + account_id=account_id, + trigger_reason=None, + dag=self.dag, + retry_from_failure=True, + **self.config, + ) + self.mock_context["ti"].try_number = 1 + + assert operator.dbt_cloud_conn_id == conn_id + assert operator.job_id == self.config["job_id"] + assert operator.account_id == account_id + assert operator.check_interval == self.config["check_interval"] + assert operator.timeout == self.config["timeout"] + assert operator.retry_from_failure + assert operator.steps_override == self.config["steps_override"] + assert operator.schema_override == self.config["schema_override"] + assert operator.additional_run_config == self.config["additional_run_config"] + + operator.execute(context=self.mock_context) + + mock_run_req.assert_called() + + @patch.object( + DbtCloudHook, "_run_and_get_response", return_value=mock_response_json(JOB_RUN_ERROR_RESPONSE) + ) + @patch.object(DbtCloudHook, "retry_failed_job_run") + @pytest.mark.parametrize( + "conn_id, account_id", + [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + def test_execute_retry_from_failure_rerun(self, mock_run_req, mock_rerun_req, conn_id, account_id): + operator = DbtCloudRunJobOperator( + task_id=TASK_ID, + dbt_cloud_conn_id=conn_id, + account_id=account_id, + trigger_reason=None, + dag=self.dag, + retry_from_failure=True, + **self.config, + ) + self.mock_context["ti"].try_number = 2 + + assert operator.dbt_cloud_conn_id == conn_id + assert operator.job_id == self.config["job_id"] + assert operator.account_id == account_id + assert operator.check_interval == self.config["check_interval"] + assert operator.timeout == self.config["timeout"] + assert operator.retry_from_failure + assert operator.steps_override == self.config["steps_override"] + assert operator.schema_override == self.config["schema_override"] + assert operator.additional_run_config == self.config["additional_run_config"] + + operator.execute(context=self.mock_context) + + mock_rerun_req.assert_called_once() + @patch.object(DbtCloudHook, "trigger_job_run") @pytest.mark.parametrize( "conn_id, account_id",