From 29d8afd9987cab1fc331bfec950e7ff9501a8787 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Thu, 27 Feb 2025 13:18:42 -0500 Subject: [PATCH] Subscribe to state changes in wait_for_flow_run (#17243) --- src/prefect/cli/deployment.py | 1 - src/prefect/flow_runs.py | 41 ++++++++++++++++++++++++++++------- tests/test_flow_runs.py | 2 +- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/prefect/cli/deployment.py b/src/prefect/cli/deployment.py index f8ecb475df8d..66131e80a5b5 100644 --- a/src/prefect/cli/deployment.py +++ b/src/prefect/cli/deployment.py @@ -894,7 +894,6 @@ async def run( soft_wrap=True, ) if watch: - watch_interval = 5 if watch_interval is None else watch_interval app.console.print(f"Watching flow run {flow_run.name!r}...") finished_flow_run = await wait_for_flow_run( flow_run.id, diff --git a/src/prefect/flow_runs.py b/src/prefect/flow_runs.py index 5fdff334a4f4..4ff5e1967c8f 100644 --- a/src/prefect/flow_runs.py +++ b/src/prefect/flow_runs.py @@ -14,6 +14,7 @@ from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas import FlowRun from prefect.client.schemas.objects import ( + State, StateType, ) from prefect.client.schemas.responses import SetStateStatus @@ -22,6 +23,8 @@ FlowRunContext, TaskRunContext, ) +from prefect.events.clients import get_events_subscriber +from prefect.events.filters import EventFilter, EventNameFilter, EventResourceFilter from prefect.exceptions import ( Abort, FlowPauseTimeout, @@ -54,7 +57,7 @@ async def wait_for_flow_run( flow_run_id: UUID, timeout: int | None = 10800, - poll_interval: int = 5, + poll_interval: int | None = None, client: "PrefectClient | None" = None, log_states: bool = False, ) -> FlowRun: @@ -64,7 +67,9 @@ async def wait_for_flow_run( Args: flow_run_id: The flow run ID for the flow run to wait for. timeout: The wait timeout in seconds. Defaults to 10800 (3 hours). - poll_interval: The poll interval in seconds. Defaults to 5. + poll_interval: Deprecated; polling is no longer used to wait for flow runs. + client: Optional Prefect client. If not provided, one will be injected. + log_states: If True, log state changes. Defaults to False. Returns: FlowRun: The finished flow run. @@ -114,17 +119,37 @@ async def main(num_runs: int): ``` """ + if poll_interval is not None: + get_logger().warning( + "The `poll_interval` argument is deprecated and will be removed in a future release. " + ) + assert client is not None, "Client injection failed" logger = get_logger() + + event_filter = EventFilter( + event=EventNameFilter(prefix=["prefect.flow-run"]), + resource=EventResourceFilter(id=[f"prefect.flow-run.{flow_run_id}"]), + ) + with anyio.move_on_after(timeout): - while True: + async with get_events_subscriber(filter=event_filter) as subscriber: flow_run = await client.read_flow_run(flow_run_id) - flow_state = flow_run.state - if log_states and flow_state: - logger.info(f"Flow run is in state {flow_state.name!r}") - if flow_state and flow_state.is_final(): + if flow_run.state and flow_run.state.is_final(): + if log_states: + logger.info(f"Flow run is in state {flow_run.state.name!r}") return flow_run - await anyio.sleep(poll_interval) + + async for event in subscriber: + state_type = StateType(event.resource["prefect.state-type"]) + state = State(type=state_type) + + if log_states: + logger.info(f"Flow run is in state {state.name!r}") + + if state.is_final(): + return await client.read_flow_run(flow_run_id) + raise FlowRunWaitTimeout( f"Flow run with ID {flow_run_id} exceeded watch timeout of {timeout} seconds" ) diff --git a/tests/test_flow_runs.py b/tests/test_flow_runs.py index 042a6f4c6c1b..861e274c894c 100644 --- a/tests/test_flow_runs.py +++ b/tests/test_flow_runs.py @@ -18,7 +18,7 @@ def foo(): flow_run = await prefect_client.create_flow_run(foo, state=Completed()) assert isinstance(flow_run, client_schemas.FlowRun) - lookup = await wait_for_flow_run(flow_run.id, poll_interval=0) + lookup = await wait_for_flow_run(flow_run.id) # Estimates will not be equal since time has passed assert lookup == flow_run assert flow_run.state