Skip to content

Commit

Permalink
Move timeout handling to functions wrapped with AsyncStatus (#318)
Browse files Browse the repository at this point in the history
* Make callbacks fire on AsyncStatus cancellation

* Move timeout handling to AsyncStatus coros

* Allow timeouts to be overridden

* Make test clearer

* Unify timeout handling between signal and motor
  • Loading branch information
coretl authored May 20, 2024
1 parent 302597c commit d2ed67e
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 235 deletions.
4 changes: 4 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from .standard_readable import ConfigSignal, HintedSignal, StandardReadable
from .utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
CalculateTimeout,
Callback,
NotConnected,
ReadingValueCallback,
Expand Down Expand Up @@ -108,6 +110,8 @@
"TriggerInfo",
"TriggerLogic",
"HardwareTriggeredFlyable",
"CalculateTimeout",
"CalculatableTimeout",
"DEFAULT_TIMEOUT",
"Callback",
"NotConnected",
Expand Down
39 changes: 11 additions & 28 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Awaitable,
Callable,
Generic,
SupportsFloat,
Type,
TypeVar,
cast,
Expand All @@ -27,15 +26,11 @@
class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None):
if isinstance(timeout, SupportsFloat):
timeout = float(timeout)
def __init__(self, awaitable: Awaitable):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(
asyncio.wait_for(awaitable, timeout=timeout)
)
self.task = asyncio.create_task(awaitable)
self.task.add_done_callback(self._run_callbacks)
self._callbacks: list[Callback[Status]] = []

Expand All @@ -49,9 +44,8 @@ def add_callback(self, callback: Callback[Status]):
self._callbacks.append(callback)

def _run_callbacks(self, task: asyncio.Task):
if not task.cancelled():
for callback in self._callbacks:
callback(self)
for callback in self._callbacks:
callback(self)

def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
Expand Down Expand Up @@ -93,14 +87,11 @@ def __repr__(self) -> str:
class AsyncStatus(AsyncStatusBase):
@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
"""Wrap an async function in an AsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
Expand All @@ -110,15 +101,11 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
timeout: SupportsFloat | None = None,
):
def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator), timeout)
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
Expand Down Expand Up @@ -146,14 +133,10 @@ def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout' as an argument, this must be a float or None, and it
will be propagated to the status."""
"""Wrap an AsyncIterator in a WatchableAsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

return cast(Callable[P, WAS], wrap_f)
40 changes: 23 additions & 17 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .device import Device
from .signal_backend import SignalBackend
from .soft_signal_backend import SoftSignalBackend
from .utils import DEFAULT_TIMEOUT, Callback, T
from .utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T


def _add_timeout(func):
Expand Down Expand Up @@ -213,15 +213,14 @@ async def unstage(self) -> None:
self._del_cache(self._get_cache().set_staged(False))


USE_DEFAULT_TIMEOUT = "USE_DEFAULT_TIMEOUT"


class SignalW(Signal[T], Movable):
"""Signal that can be set"""

def set(self, value: T, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
def set(
self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout
) -> AsyncStatus:
"""Set the value and return a status saying when it's done"""
if timeout is USE_DEFAULT_TIMEOUT:
if timeout is CalculateTimeout:
timeout = self._timeout

async def do_set():
Expand All @@ -248,9 +247,11 @@ async def locate(self) -> Location:
class SignalX(Signal):
"""Signal that puts the default value"""

def trigger(self, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
def trigger(
self, wait=True, timeout: CalculatableTimeout = CalculateTimeout
) -> AsyncStatus:
"""Trigger the action and return a status saying when it's done"""
if timeout is USE_DEFAULT_TIMEOUT:
if timeout is CalculateTimeout:
timeout = self._timeout
coro = self._backend.put(None, wait=wait, timeout=timeout)
return AsyncStatus(coro)
Expand All @@ -270,7 +271,7 @@ def soft_signal_r_and_setter(
datatype: Optional[Type[T]] = None,
initial_value: Optional[T] = None,
name: str = "",
) -> Tuple[SignalR[T], Callable[[T]]]:
) -> Tuple[SignalR[T], Callable[[T], None]]:
"""Returns a tuple of a read-only Signal and a callable through
which the signal can be internally modified within the device. Use
soft_signal_rw if you want a device that is externally modifiable
Expand Down Expand Up @@ -394,7 +395,7 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int):


async def observe_value(
signal: SignalR[T], timeout=None, done_status: Status | None = None
signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None
) -> AsyncGenerator[T, None]:
"""Subscribe to the value of a signal so it can be iterated from.
Expand All @@ -403,8 +404,12 @@ async def observe_value(
signal:
Call subscribe_value on this at the start, and clear_sub on it at the
end
timeout:
If given, how long to wait for each updated value in seconds. If an update
is not produced in this time then raise asyncio.TimeoutError
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.
Notes
-----
Expand All @@ -414,9 +419,7 @@ async def observe_value(
do_something_with(value)
"""

class StatusIsDone: ...

q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue()
q: asyncio.Queue[T | Status] = asyncio.Queue()
if timeout is None:
get_value = q.get
else:
Expand All @@ -425,16 +428,19 @@ async def get_value():
return await asyncio.wait_for(q.get(), timeout)

if done_status is not None:
done_status.add_callback(lambda _: q.put_nowait(StatusIsDone()))
done_status.add_callback(q.put_nowait)

signal.subscribe_value(q.put_nowait)
try:
while True:
item = await get_value()
if not isinstance(item, StatusIsDone):
yield item
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
else:
break
else:
break
yield item
finally:
signal.clear_sub(q.put_nowait)

Expand Down
11 changes: 11 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
ErrorText = Union[str, Dict[str, Exception]]


class CalculateTimeout:
"""Sentinel class used to implement ``myfunc(timeout=CalculateTimeout)``
This signifies that the function should calculate a suitable non-zero
timeout itself
"""


CalculatableTimeout = float | None | Type[CalculateTimeout]


class NotConnected(Exception):
"""Exception to be raised if a `Device.connect` is cancelled"""

Expand Down
56 changes: 30 additions & 26 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import string
import subprocess
import sys
from dataclasses import replace
from enum import Enum
from pathlib import Path

Expand All @@ -22,7 +21,13 @@
WatchableAsyncStatus,
observe_value,
)
from ophyd_async.core.utils import WatcherUpdate
from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
CalculateTimeout,
WatcherUpdate,
)

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -66,11 +71,9 @@ def __init__(self, prefix: str, name="") -> None:
# Define some signals
with self.add_children_as_readables(HintedSignal):
self.readback = epics_signal_r(float, prefix + "Readback")

with self.add_children_as_readables(ConfigSignal):
self.velocity = epics_signal_rw(float, prefix + "Velocity")
self.units = epics_signal_r(str, prefix + "Readback.EGU")

self.setpoint = epics_signal_rw(float, prefix + "Setpoint")
self.precision = epics_signal_r(int, prefix + "Readback.PREC")
# Signals that collide with standard methods should have a trailing underscore
Expand All @@ -85,41 +88,42 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float):
@WatchableAsyncStatus.wrap
async def set(
self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout
):
self._set_success = True
# time.monotonic won't go backwards in case of NTP corrections
old_position, units, precision = await asyncio.gather(
old_position, units, precision, velocity = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
self.velocity.get_value(),
)
if timeout is CalculateTimeout:
assert velocity > 0, "Mover has zero velocity"
timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT
# Make an Event that will be set on completion, and a Status that will
# error if not done in time
done = asyncio.Event()
done_status = AsyncStatus(asyncio.wait_for(done.wait(), timeout))
# Wait for the value to set, but don't wait for put completion callback
move_status = self.setpoint.set(new_position, wait=True)
if not self._set_success:
raise RuntimeError("Motor was stopped")
# return a template to set() which it can use to yield progress updates
return (
WatcherUpdate(
await self.setpoint.set(new_position, wait=False)
async for current_position in observe_value(
self.readback, done_status=done_status
):
yield WatcherUpdate(
current=current_position,
initial=old_position,
current=old_position,
target=new_position,
name=self.name,
unit=units,
precision=precision,
),
move_status,
)

@WatchableAsyncStatus.wrap # uses the timeout argument from the function it wraps
async def set(self, new_position: float, timeout: float | None = None):
update, _ = await self._move(new_position)
async for current_position in observe_value(self.readback):
yield replace(
update,
name=self.name,
current=current_position,
)
if np.isclose(current_position, new_position):
done.set()
break
if not self._set_success:
raise RuntimeError("Motor was stopped")

async def stop(self, success=True):
self._set_success = success
Expand Down
Loading

0 comments on commit d2ed67e

Please sign in to comment.