Skip to content

Commit

Permalink
Subscribe to state changes in wait_for_flow_run (#17243)
Browse files Browse the repository at this point in the history
  • Loading branch information
bnaul authored Feb 27, 2025
1 parent b8ff0bb commit 29d8afd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
1 change: 0 additions & 1 deletion src/prefect/cli/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,6 @@ async def run(
soft_wrap=True,
)
if watch:
watch_interval = 5 if watch_interval is None else watch_interval
app.console.print(f"Watching flow run {flow_run.name!r}...")
finished_flow_run = await wait_for_flow_run(
flow_run.id,
Expand Down
41 changes: 33 additions & 8 deletions src/prefect/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas import FlowRun
from prefect.client.schemas.objects import (
State,
StateType,
)
from prefect.client.schemas.responses import SetStateStatus
Expand All @@ -22,6 +23,8 @@
FlowRunContext,
TaskRunContext,
)
from prefect.events.clients import get_events_subscriber
from prefect.events.filters import EventFilter, EventNameFilter, EventResourceFilter
from prefect.exceptions import (
Abort,
FlowPauseTimeout,
Expand Down Expand Up @@ -54,7 +57,7 @@
async def wait_for_flow_run(
flow_run_id: UUID,
timeout: int | None = 10800,
poll_interval: int = 5,
poll_interval: int | None = None,
client: "PrefectClient | None" = None,
log_states: bool = False,
) -> FlowRun:
Expand All @@ -64,7 +67,9 @@ async def wait_for_flow_run(
Args:
flow_run_id: The flow run ID for the flow run to wait for.
timeout: The wait timeout in seconds. Defaults to 10800 (3 hours).
poll_interval: The poll interval in seconds. Defaults to 5.
poll_interval: Deprecated; polling is no longer used to wait for flow runs.
client: Optional Prefect client. If not provided, one will be injected.
log_states: If True, log state changes. Defaults to False.
Returns:
FlowRun: The finished flow run.
Expand Down Expand Up @@ -114,17 +119,37 @@ async def main(num_runs: int):
```
"""
if poll_interval is not None:
get_logger().warning(
"The `poll_interval` argument is deprecated and will be removed in a future release. "
)

assert client is not None, "Client injection failed"
logger = get_logger()

event_filter = EventFilter(
event=EventNameFilter(prefix=["prefect.flow-run"]),
resource=EventResourceFilter(id=[f"prefect.flow-run.{flow_run_id}"]),
)

with anyio.move_on_after(timeout):
while True:
async with get_events_subscriber(filter=event_filter) as subscriber:
flow_run = await client.read_flow_run(flow_run_id)
flow_state = flow_run.state
if log_states and flow_state:
logger.info(f"Flow run is in state {flow_state.name!r}")
if flow_state and flow_state.is_final():
if flow_run.state and flow_run.state.is_final():
if log_states:
logger.info(f"Flow run is in state {flow_run.state.name!r}")
return flow_run
await anyio.sleep(poll_interval)

async for event in subscriber:
state_type = StateType(event.resource["prefect.state-type"])
state = State(type=state_type)

if log_states:
logger.info(f"Flow run is in state {state.name!r}")

if state.is_final():
return await client.read_flow_run(flow_run_id)

raise FlowRunWaitTimeout(
f"Flow run with ID {flow_run_id} exceeded watch timeout of {timeout} seconds"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def foo():
flow_run = await prefect_client.create_flow_run(foo, state=Completed())
assert isinstance(flow_run, client_schemas.FlowRun)

lookup = await wait_for_flow_run(flow_run.id, poll_interval=0)
lookup = await wait_for_flow_run(flow_run.id)
# Estimates will not be equal since time has passed
assert lookup == flow_run
assert flow_run.state
Expand Down

0 comments on commit 29d8afd

Please sign in to comment.