Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @kubernetes decorator #17248

Merged
merged 5 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations

import inspect
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Iterable,
NoReturn,
Optional,
TypeVar,
overload,
)

from prefect_kubernetes.worker import KubernetesWorker
from typing_extensions import Literal, ParamSpec

from prefect import Flow, State
from prefect.futures import PrefectFuture
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import get_call_parameters

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


class InfrastructureBoundFlow(Flow[P, R]):
def __init__(
self,
*args: Any,
work_pool: str,
job_variables: dict[str, Any],
# TODO: Update this to use BaseWorker when the .submit method is moved to the base class
worker_cls: type[KubernetesWorker],
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.work_pool = work_pool
self.job_variables = job_variables
self.worker_cls = worker_cls

@classmethod
def from_flow(
cls,
flow: Flow[P, R],
work_pool: str,
job_variables: dict[str, Any],
worker_cls: type[KubernetesWorker],
) -> InfrastructureBoundFlow[P, R]:
new = cls(
flow.fn,
work_pool=work_pool,
job_variables=job_variables,
worker_cls=worker_cls,
)
# Copy all attributes from the original flow
for attr, value in flow.__dict__.items():
setattr(new, attr, value)
return new

@overload
def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> None:
# `NoReturn` matches if a type can't be inferred for the function which stops a
# sync function from matching the `Coroutine` overload
...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
**kwargs: P.kwargs,
) -> Coroutine[Any, Any, T]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
**kwargs: P.kwargs,
) -> T: ...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> Awaitable[State[T]]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> State[T]: ...

def __call__(
self,
*args: "P.args",
return_state: bool = False,
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: "P.kwargs",
):
async def modified_call(
*args: P.args,
return_state: bool = False,
# TODO: Handle wait_for once we have an asynchronous way to wait for futures
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: P.kwargs,
) -> R | State[R]:
async with self.worker_cls(work_pool_name=self.work_pool) as worker:
parameters = get_call_parameters(self, args, kwargs)
future = await worker.submit(
flow=self,
parameters=parameters,
job_variables=self.job_variables,
)
if return_state:
await future.wait_async()
return future.state
return await future.aresult()

if inspect.iscoroutinefunction(self.fn):
return modified_call(
*args, return_state=return_state, wait_for=wait_for, **kwargs
)
else:
return run_coro_as_sync(
modified_call(
*args,
return_state=return_state,
wait_for=wait_for,
**kwargs,
)
)


def kubernetes(
work_pool: str, **job_variables: Any
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'd be nice to have a TypedDict here to Unpack someday, not sure how we'd manage that, but just throwing it out there

Suggested change
work_pool: str, **job_variables: Any
work_pool: str, **job_variables: Any

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could see a world where we can generate a decorator for a specific work pool that includes the job variables as named and type kwargs. We're probably a ways off from that though.

) -> Callable[[Flow[P, R]], Flow[P, R]]:
"""
Decorator that binds execution of a flow to a Kubernetes work pool

Args:
work_pool: The name of the Kubernetes work pool to use
**job_variables: Additional job variables to use for infrastructure configuration

Example:
```python
from prefect import flow
from prefect_kubernetes import kubernetes

@kubernetes(work_pool="my-pool")
@flow
def my_flow():
...

# This will run the flow in a Kubernetes job
my_flow()
```
"""

def decorator(flow: Flow[P, R]) -> InfrastructureBoundFlow[P, R]:
return InfrastructureBoundFlow.from_flow(
flow,
work_pool=work_pool,
job_variables=job_variables,
worker_cls=KubernetesWorker,
)

return decorator
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Generator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prefect_kubernetes.experimental.decorators import kubernetes
from prefect_kubernetes.worker import KubernetesWorker

from prefect import State, flow
from prefect.futures import PrefectFuture


@pytest.fixture
def mock_submit() -> Generator[AsyncMock, None, None]:
"""Create a mock for the KubernetesWorker.submit method"""
# Create a mock state
mock_state = MagicMock(spec=State)
mock_state.is_completed.return_value = True
mock_state.message = "Success"

# Create a mock future
mock_future = MagicMock(spec=PrefectFuture)
mock_future.aresult = AsyncMock(return_value="test_result")
mock_future.wait_async = AsyncMock()
mock_future.state = mock_state

mock = AsyncMock(return_value=mock_future)

patcher = patch.object(KubernetesWorker, "submit", mock)
patcher.start()
yield mock
patcher.stop()


def test_kubernetes_decorator_sync_flow(mock_submit: AsyncMock) -> None:
"""Test that a synchronous flow is correctly decorated and executed"""

@kubernetes(work_pool="test-pool", memory="2Gi")
@flow
def sync_test_flow(param1, param2="default"):
return f"{param1}-{param2}"

result = sync_test_flow("test")

mock_submit.assert_called_once()
args, kwargs = mock_submit.call_args
assert kwargs["parameters"] == {"param1": "test", "param2": "default"}
assert kwargs["job_variables"] == {"memory": "2Gi"}
assert result == "test_result"


async def test_kubernetes_decorator_async_flow(mock_submit: AsyncMock) -> None:
"""Test that an asynchronous flow is correctly decorated and executed"""

@kubernetes(work_pool="test-pool", cpu="1")
@flow
async def async_test_flow(param1):
return f"async-{param1}"

result = await async_test_flow("test")

mock_submit.assert_called_once()
args, kwargs = mock_submit.call_args
assert kwargs["parameters"] == {"param1": "test"}
assert kwargs["job_variables"] == {"cpu": "1"}
assert result == "test_result"


@pytest.mark.usefixtures("mock_submit")
def test_kubernetes_decorator_return_state() -> None:
"""Test that return_state=True returns the state instead of the result"""

@kubernetes(work_pool="test-pool")
@flow
def test_flow():
return "completed"

state = test_flow(return_state=True)

assert state.is_completed()
assert state.message == "Success"


@pytest.mark.usefixtures("mock_submit")
def test_kubernetes_decorator_preserves_flow_attributes() -> None:
"""Test that the decorator preserves the original flow's attributes"""

@flow(name="custom-flow-name", description="Custom description")
def original_flow():
return "test"

original_name = original_flow.name
original_description = original_flow.description

decorated_flow = kubernetes(work_pool="test-pool")(original_flow)

assert decorated_flow.name == original_name
assert decorated_flow.description == original_description

result = decorated_flow()
assert result == "test_result"


def test_submit_method_receives_work_pool_name(mock_submit: AsyncMock) -> None:
"""Test that the correct work pool name is passed to submit"""

@kubernetes(work_pool="specific-pool")
@flow
def test_flow():
return "test"

test_flow()

mock_submit.assert_called_once()
kwargs = mock_submit.call_args.kwargs
assert "flow" in kwargs
assert "parameters" in kwargs
assert "job_variables" in kwargs