diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index f96d3ba2a..d3466ea16 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -270,6 +270,7 @@ def create_task_metadata( } if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES: + args["on_warning_callback"] = on_warning_callback exclude_detached_tests_if_needed(node, args, detached_from_parent) task_id, args = _get_task_id_and_args( node, args, use_task_group, normalize_task_id, "build", include_resource_type=True diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 21fa6ae91..af64fa13a 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -762,8 +762,36 @@ class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + self.on_warning_callback = on_warning_callback + self.extract_issues: Callable[..., tuple[list[str], list[str]]] + + def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: + """ + Handles warnings by extracting log issues, creating additional context, and calling the + on_warning_callback with the updated context. + + :param result: The result object from the build and run command. + :param context: The original airflow context in which the build and run command was executed. + """ + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = extract_freshness_warn_msg + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = dbt_runner.extract_message_by_status + + test_names, test_results = self.extract_issues(result) + + warning_context = dict(context) + warning_context["test_names"] = test_names + warning_context["test_results"] = test_results + + self.on_warning_callback and self.on_warning_callback(warning_context) + + def execute(self, context: Context, **kwargs: Any) -> None: + result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) + if self.on_warning_callback: + self._handle_warnings(result, context) class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator): diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 3bcd78616..a721cd858 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -664,7 +664,17 @@ def test_run_test_operator_with_callback(invocation_mode, failing_test_dbt_proje on_warning_callback=on_warning_callback, invocation_mode=invocation_mode, ) - run_operator >> test_operator + + build_operator = DbtBuildLocalOperator( + profile_config=mini_profile_config, + project_dir=failing_test_dbt_project, + task_id="build", + append_env=True, + on_warning_callback=on_warning_callback, + invocation_mode=invocation_mode, + ) + + run_operator >> build_operator >> test_operator run_test_dag(dag) assert on_warning_callback.called