Skip to content

Commit

Permalink
Improve ExecutionCallableRunner (#43812)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Nov 8, 2024
1 parent 80e9a94 commit 50aabd2
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions airflow/utils/operator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
# under the License.
from __future__ import annotations

import inspect
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, Protocol, TypeVar

from airflow import settings
from airflow.assets.metadata import Metadata
from airflow.typing_compat import ParamSpec
from airflow.utils.context import Context, lazy_mapping_from_context
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.utils.context import OutletEventAccessors

P = ParamSpec("P")
R = TypeVar("R")

DEFAULT_FORMAT_PREFIX = "airflow.ctx."
Expand Down Expand Up @@ -225,7 +230,17 @@ def kwargs_func(*args, **kwargs):
return kwargs_func


class ExecutionCallableRunner:
class _ExecutionCallableRunner(Protocol):
@staticmethod
def run(*args, **kwargs): ...


def ExecutionCallableRunner(
func: Callable[P, R],
outlet_events: OutletEventAccessors,
*,
logger: logging.Logger,
) -> _ExecutionCallableRunner:
"""
Run an execution callable against a task context and given arguments.
Expand All @@ -234,45 +249,42 @@ class ExecutionCallableRunner:
the generator is exhausted here, with the yielded values getting fed back
into the task context automatically for execution.
This convoluted implementation (an inner class whose static ``run()`` closes
over the outer arguments) exists so that *all* arguments passed to ``run()``
can be forwarded to the wrapped function. This is particularly important for
an argument named "self", which some use cases need to receive. That is not
possible with a normal class, where "self" must refer to the
ExecutionCallableRunner object itself.
The function name deliberately violates PEP 8 to preserve backward
compatibility: this was previously implemented as a class.
:meta private:
"""

def __init__(
self,
func: Callable,
outlet_events: OutletEventAccessors,
*,
logger: logging.Logger | None,
) -> None:
self.func = func
self.outlet_events = outlet_events
self.logger = logger or logging.getLogger(__name__)

def run(self, *args, **kwargs) -> Any:
import inspect

from airflow.assets.metadata import Metadata
from airflow.utils.types import NOTSET
class _ExecutionCallableRunnerImpl:
@staticmethod
def run(*args: P.args, **kwargs: P.kwargs) -> R:
if not inspect.isgeneratorfunction(func):
return func(*args, **kwargs)

if not inspect.isgeneratorfunction(self.func):
return self.func(*args, **kwargs)
result: Any = NOTSET

result: Any = NOTSET
def _run():
nonlocal result
result = yield from func(*args, **kwargs)

def _run():
nonlocal result
result = yield from self.func(*args, **kwargs)
for metadata in _run():
if isinstance(metadata, Metadata):
outlet_events[metadata.uri].extra.update(metadata.extra)

for metadata in _run():
if isinstance(metadata, Metadata):
self.outlet_events[metadata.uri].extra.update(metadata.extra)
if metadata.alias_name:
outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra)

if metadata.alias_name:
self.outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra)
continue
logger.warning("Ignoring unknown data of %r received from task", type(metadata))
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Full yielded value: %r", metadata)

continue
self.logger.warning("Ignoring unknown data of %r received from task", type(metadata))
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("Full yielded value: %r", metadata)
return result

return result
return _ExecutionCallableRunnerImpl

0 comments on commit 50aabd2

Please sign in to comment.