Skip to content

Commit

Permalink
Update decorator to maintain flowiness
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Feb 24, 2025
1 parent 1b666d4 commit d9ade40
Showing 1 changed file with 101 additions and 24 deletions.
125 changes: 101 additions & 24 deletions src/integrations/prefect-kubernetes/prefect_kubernetes/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, Callable, TypeVar
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Iterable,
NoReturn,
Optional,
TypeVar,
overload,
)

from typing_extensions import ParamSpec
from typing_extensions import Literal, ParamSpec

from prefect import Flow
from prefect import Flow, State
from prefect.futures import PrefectFuture
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import get_call_parameters
from prefect_kubernetes.worker import KubernetesWorker

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def kubernetes(
work_pool: str, **job_variables: Any
) -> Callable[[Flow[P, R]], Callable[P, R]]:
) -> Callable[[Flow[P, R]], Flow[P, R]]:
"""
Decorator that binds execution of a flow to a Kubernetes work pool
Expand All @@ -40,24 +51,90 @@ def my_flow():
```
"""

def decorator(flow: Flow[P, R]) -> Callable[P, R]:
@wraps(flow)
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async with KubernetesWorker(work_pool_name=work_pool) as worker:
parameters = get_call_parameters(flow, args, kwargs)
future = await worker.submit(
flow=flow, parameters=parameters, job_variables=job_variables
)
return await future.aresult()

if inspect.iscoroutinefunction(flow.fn):
return awrapper
else:

@wraps(flow)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return run_coro_as_sync(awrapper(*args, **kwargs))

return wrapper
def decorator(flow: Flow[P, R]) -> Flow[P, R]:
class KubernetesFlow(Flow[P, R]):
@overload
def __call__(
self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs
) -> None:
# `NoReturn` matches if a type can't be inferred for the function which stops a
# sync function from matching the `Coroutine` overload
...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
**kwargs: P.kwargs,
) -> Coroutine[Any, Any, T]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
**kwargs: P.kwargs,
) -> T: ...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> Awaitable[State[T]]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> State[T]: ...

def __call__(
self,
*args: "P.args",
return_state: bool = False,
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: "P.kwargs",
):
async def modified_call(
*args: P.args,
return_state: bool = False,
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: P.kwargs,
) -> R | State[R]:
async with KubernetesWorker(work_pool_name=work_pool) as worker:
parameters = get_call_parameters(flow, args, kwargs)
future = await worker.submit(
flow=flow,
parameters=parameters,
job_variables=job_variables,
)
if return_state:
await future.wait_async()
return future.state
return await future.aresult()

if inspect.iscoroutinefunction(self.fn):
return modified_call(
*args, return_state=return_state, wait_for=wait_for, **kwargs
)
else:
return run_coro_as_sync(
modified_call(
*args,
return_state=return_state,
wait_for=wait_for,
**kwargs,
)
)

flow_copy = KubernetesFlow(flow.fn)

for attr, value in flow.__dict__.items():
setattr(flow_copy, attr, value)

return flow_copy

return decorator

0 comments on commit d9ade40

Please sign in to comment.