diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index 6e5e03d8d163..a259b455776b 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -17,16 +17,21 @@ # under the License. from __future__ import annotations +import inspect import logging from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, Protocol, TypeVar from airflow import settings +from airflow.assets.metadata import Metadata +from airflow.typing_compat import ParamSpec from airflow.utils.context import Context, lazy_mapping_from_context +from airflow.utils.types import NOTSET if TYPE_CHECKING: from airflow.utils.context import OutletEventAccessors +P = ParamSpec("P") R = TypeVar("R") DEFAULT_FORMAT_PREFIX = "airflow.ctx." @@ -225,7 +230,17 @@ def kwargs_func(*args, **kwargs): return kwargs_func -class ExecutionCallableRunner: +class _ExecutionCallableRunner(Protocol): + @staticmethod + def run(*args, **kwargs): ... + + +def ExecutionCallableRunner( + func: Callable[P, R], + outlet_events: OutletEventAccessors, + *, + logger: logging.Logger, +) -> _ExecutionCallableRunner: """ Run an execution callable against a task context and given arguments. @@ -234,45 +249,42 @@ class ExecutionCallableRunner: the generator is exhausted here, with the yielded values getting fed back into the task context automatically for execution. + This convoluted implementation of inner class with closure is so *all* + arguments passed to ``run()`` can be forwarded to the wrapped function. This + is particularly important for the argument "self", which some use cases + need to receive. This is not possible if this is implemented as a normal + class, where "self" needs to point to the ExecutionCallableRunner object. + + The function name violates PEP 8 due to backward compatibility. This was + implemented as a class previously. + :meta private: """ - def __init__( - self, - func: Callable, - outlet_events: OutletEventAccessors, - *, - logger: logging.Logger | None, - ) -> None: - self.func = func - self.outlet_events = outlet_events - self.logger = logger or logging.getLogger(__name__) - - def run(self, *args, **kwargs) -> Any: - import inspect - - from airflow.assets.metadata import Metadata - from airflow.utils.types import NOTSET + class _ExecutionCallableRunnerImpl: + @staticmethod + def run(*args: P.args, **kwargs: P.kwargs) -> R: + if not inspect.isgeneratorfunction(func): + return func(*args, **kwargs) - if not inspect.isgeneratorfunction(self.func): - return self.func(*args, **kwargs) + result: Any = NOTSET - result: Any = NOTSET + def _run(): + nonlocal result + result = yield from func(*args, **kwargs) - def _run(): - nonlocal result - result = yield from self.func(*args, **kwargs) + for metadata in _run(): + if isinstance(metadata, Metadata): + outlet_events[metadata.uri].extra.update(metadata.extra) - for metadata in _run(): - if isinstance(metadata, Metadata): - self.outlet_events[metadata.uri].extra.update(metadata.extra) + if metadata.alias_name: + outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra) - if metadata.alias_name: - self.outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra) + continue + logger.warning("Ignoring unknown data of %r received from task", type(metadata)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Full yielded value: %r", metadata) - continue - self.logger.warning("Ignoring unknown data of %r received from task", type(metadata)) - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug("Full yielded value: %r", metadata) + return result - return result + return _ExecutionCallableRunnerImpl