-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add @kubernetes
decorator
#17248
Merged
Merged
Add @kubernetes
decorator
#17248
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
91cd1b1
Adds `@kubernetes` decorator
desertaxle e3eeb8f
Update decorator to maintain flowiness
desertaxle 7a6324f
Move class outside of decorator
desertaxle 29934e9
initial test coverage for kubernetes decorator (#17275)
zzstoatzz 8b3d8d3
Lil' cleanup
desertaxle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
173 changes: 173 additions & 0 deletions
173
src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import ( | ||
Any, | ||
Awaitable, | ||
Callable, | ||
Coroutine, | ||
Iterable, | ||
NoReturn, | ||
Optional, | ||
TypeVar, | ||
overload, | ||
) | ||
|
||
from prefect_kubernetes.worker import KubernetesWorker | ||
from typing_extensions import Literal, ParamSpec | ||
|
||
from prefect import Flow, State | ||
from prefect.futures import PrefectFuture | ||
from prefect.utilities.asyncutils import run_coro_as_sync | ||
from prefect.utilities.callables import get_call_parameters | ||
|
||
P = ParamSpec("P") | ||
R = TypeVar("R") | ||
T = TypeVar("T") | ||
|
||
|
||
class InfrastructureBoundFlow(Flow[P, R]): | ||
def __init__( | ||
self, | ||
*args: Any, | ||
work_pool: str, | ||
job_variables: dict[str, Any], | ||
# TODO: Update this to use BaseWorker when the .submit method is moved to the base class | ||
worker_cls: type[KubernetesWorker], | ||
**kwargs: Any, | ||
): | ||
super().__init__(*args, **kwargs) | ||
self.work_pool = work_pool | ||
self.job_variables = job_variables | ||
self.worker_cls = worker_cls | ||
|
||
@classmethod | ||
def from_flow( | ||
cls, | ||
flow: Flow[P, R], | ||
work_pool: str, | ||
job_variables: dict[str, Any], | ||
worker_cls: type[KubernetesWorker], | ||
) -> InfrastructureBoundFlow[P, R]: | ||
new = cls( | ||
flow.fn, | ||
work_pool=work_pool, | ||
job_variables=job_variables, | ||
worker_cls=worker_cls, | ||
) | ||
# Copy all attributes from the original flow | ||
for attr, value in flow.__dict__.items(): | ||
setattr(new, attr, value) | ||
return new | ||
|
||
@overload | ||
def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> None: | ||
# `NoReturn` matches if a type can't be inferred for the function which stops a | ||
# sync function from matching the `Coroutine` overload | ||
... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, Coroutine[Any, Any, T]]", | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> Coroutine[Any, Any, T]: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, T]", | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> T: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, Coroutine[Any, Any, T]]", | ||
*args: P.args, | ||
return_state: Literal[True], | ||
**kwargs: P.kwargs, | ||
) -> Awaitable[State[T]]: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, T]", | ||
*args: P.args, | ||
return_state: Literal[True], | ||
**kwargs: P.kwargs, | ||
) -> State[T]: ... | ||
|
||
def __call__( | ||
self, | ||
*args: "P.args", | ||
return_state: bool = False, | ||
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, | ||
**kwargs: "P.kwargs", | ||
): | ||
async def modified_call( | ||
*args: P.args, | ||
return_state: bool = False, | ||
# TODO: Handle wait_for once we have an asynchronous way to wait for futures | ||
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, | ||
**kwargs: P.kwargs, | ||
) -> R | State[R]: | ||
async with self.worker_cls(work_pool_name=self.work_pool) as worker: | ||
parameters = get_call_parameters(self, args, kwargs) | ||
future = await worker.submit( | ||
flow=self, | ||
parameters=parameters, | ||
job_variables=self.job_variables, | ||
) | ||
if return_state: | ||
await future.wait_async() | ||
return future.state | ||
return await future.aresult() | ||
|
||
if inspect.iscoroutinefunction(self.fn): | ||
return modified_call( | ||
*args, return_state=return_state, wait_for=wait_for, **kwargs | ||
) | ||
else: | ||
return run_coro_as_sync( | ||
modified_call( | ||
*args, | ||
return_state=return_state, | ||
wait_for=wait_for, | ||
**kwargs, | ||
) | ||
) | ||
|
||
|
||
def kubernetes( | ||
work_pool: str, **job_variables: Any | ||
) -> Callable[[Flow[P, R]], Flow[P, R]]: | ||
""" | ||
Decorator that binds execution of a flow to a Kubernetes work pool | ||
|
||
Args: | ||
work_pool: The name of the Kubernetes work pool to use | ||
**job_variables: Additional job variables to use for infrastructure configuration | ||
|
||
Example: | ||
```python | ||
from prefect import flow | ||
from prefect_kubernetes import kubernetes | ||
|
||
@kubernetes(work_pool="my-pool") | ||
@flow | ||
def my_flow(): | ||
... | ||
|
||
# This will run the flow in a Kubernetes job | ||
my_flow() | ||
``` | ||
""" | ||
|
||
def decorator(flow: Flow[P, R]) -> InfrastructureBoundFlow[P, R]: | ||
return InfrastructureBoundFlow.from_flow( | ||
flow, | ||
work_pool=work_pool, | ||
job_variables=job_variables, | ||
worker_cls=KubernetesWorker, | ||
) | ||
|
||
return decorator |
117 changes: 117 additions & 0 deletions
117
src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Generator | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
from prefect_kubernetes.experimental.decorators import kubernetes | ||
from prefect_kubernetes.worker import KubernetesWorker | ||
|
||
from prefect import State, flow | ||
from prefect.futures import PrefectFuture | ||
|
||
|
||
@pytest.fixture | ||
def mock_submit() -> Generator[AsyncMock, None, None]: | ||
"""Create a mock for the KubernetesWorker.submit method""" | ||
# Create a mock state | ||
mock_state = MagicMock(spec=State) | ||
mock_state.is_completed.return_value = True | ||
mock_state.message = "Success" | ||
|
||
# Create a mock future | ||
mock_future = MagicMock(spec=PrefectFuture) | ||
mock_future.aresult = AsyncMock(return_value="test_result") | ||
mock_future.wait_async = AsyncMock() | ||
mock_future.state = mock_state | ||
|
||
mock = AsyncMock(return_value=mock_future) | ||
|
||
patcher = patch.object(KubernetesWorker, "submit", mock) | ||
patcher.start() | ||
yield mock | ||
patcher.stop() | ||
|
||
|
||
def test_kubernetes_decorator_sync_flow(mock_submit: AsyncMock) -> None: | ||
"""Test that a synchronous flow is correctly decorated and executed""" | ||
|
||
@kubernetes(work_pool="test-pool", memory="2Gi") | ||
@flow | ||
def sync_test_flow(param1, param2="default"): | ||
return f"{param1}-{param2}" | ||
|
||
result = sync_test_flow("test") | ||
|
||
mock_submit.assert_called_once() | ||
args, kwargs = mock_submit.call_args | ||
assert kwargs["parameters"] == {"param1": "test", "param2": "default"} | ||
assert kwargs["job_variables"] == {"memory": "2Gi"} | ||
assert result == "test_result" | ||
|
||
|
||
async def test_kubernetes_decorator_async_flow(mock_submit: AsyncMock) -> None: | ||
"""Test that an asynchronous flow is correctly decorated and executed""" | ||
|
||
@kubernetes(work_pool="test-pool", cpu="1") | ||
@flow | ||
async def async_test_flow(param1): | ||
return f"async-{param1}" | ||
|
||
result = await async_test_flow("test") | ||
|
||
mock_submit.assert_called_once() | ||
args, kwargs = mock_submit.call_args | ||
assert kwargs["parameters"] == {"param1": "test"} | ||
assert kwargs["job_variables"] == {"cpu": "1"} | ||
assert result == "test_result" | ||
|
||
|
||
@pytest.mark.usefixtures("mock_submit") | ||
def test_kubernetes_decorator_return_state() -> None: | ||
"""Test that return_state=True returns the state instead of the result""" | ||
|
||
@kubernetes(work_pool="test-pool") | ||
@flow | ||
def test_flow(): | ||
return "completed" | ||
|
||
state = test_flow(return_state=True) | ||
|
||
assert state.is_completed() | ||
assert state.message == "Success" | ||
|
||
|
||
@pytest.mark.usefixtures("mock_submit") | ||
def test_kubernetes_decorator_preserves_flow_attributes() -> None: | ||
"""Test that the decorator preserves the original flow's attributes""" | ||
|
||
@flow(name="custom-flow-name", description="Custom description") | ||
def original_flow(): | ||
return "test" | ||
|
||
original_name = original_flow.name | ||
original_description = original_flow.description | ||
|
||
decorated_flow = kubernetes(work_pool="test-pool")(original_flow) | ||
|
||
assert decorated_flow.name == original_name | ||
assert decorated_flow.description == original_description | ||
|
||
result = decorated_flow() | ||
assert result == "test_result" | ||
|
||
|
||
def test_submit_method_receives_work_pool_name(mock_submit: AsyncMock) -> None: | ||
"""Test that the correct work pool name is passed to submit""" | ||
|
||
@kubernetes(work_pool="specific-pool") | ||
@flow | ||
def test_flow(): | ||
return "test" | ||
|
||
test_flow() | ||
|
||
mock_submit.assert_called_once() | ||
kwargs = mock_submit.call_args.kwargs | ||
assert "flow" in kwargs | ||
assert "parameters" in kwargs | ||
assert "job_variables" in kwargs |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it'd be nice to have a
TypedDict
here toUnpack
someday, not sure how we'd manage that, but just throwing it out thereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could see a world where we can generate a decorator for a specific work pool that includes the job variables as named and type kwargs. We're probably a ways off from that though.