Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move timeout handling to functions wrapped with AsyncStatus #318

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from .standard_readable import ConfigSignal, HintedSignal, StandardReadable
from .utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
CalculateTimeout,
Callback,
NotConnected,
ReadingValueCallback,
Expand Down Expand Up @@ -108,6 +110,8 @@
"TriggerInfo",
"TriggerLogic",
"HardwareTriggeredFlyable",
"CalculateTimeout",
"CalculatableTimeout",
"DEFAULT_TIMEOUT",
"Callback",
"NotConnected",
Expand Down
39 changes: 11 additions & 28 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Awaitable,
Callable,
Generic,
SupportsFloat,
Type,
TypeVar,
cast,
Expand All @@ -27,15 +26,11 @@
class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None):
if isinstance(timeout, SupportsFloat):
timeout = float(timeout)
def __init__(self, awaitable: Awaitable):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(
asyncio.wait_for(awaitable, timeout=timeout)
)
self.task = asyncio.create_task(awaitable)
self.task.add_done_callback(self._run_callbacks)
self._callbacks: list[Callback[Status]] = []

Expand All @@ -49,9 +44,8 @@ def add_callback(self, callback: Callback[Status]):
self._callbacks.append(callback)

def _run_callbacks(self, task: asyncio.Task):
if not task.cancelled():
for callback in self._callbacks:
callback(self)
for callback in self._callbacks:
callback(self)

def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
Expand Down Expand Up @@ -93,14 +87,11 @@ def __repr__(self) -> str:
class AsyncStatus(AsyncStatusBase):
@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
"""Wrap an async function in an AsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
Expand All @@ -110,15 +101,11 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
timeout: SupportsFloat | None = None,
):
def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator), timeout)
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
Expand Down Expand Up @@ -146,14 +133,10 @@ def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout' as an argument, this must be a float or None, and it
will be propagated to the status."""
"""Wrap an AsyncIterator in a WatchableAsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

return cast(Callable[P, WAS], wrap_f)
40 changes: 23 additions & 17 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .device import Device
from .signal_backend import SignalBackend
from .soft_signal_backend import SoftSignalBackend
from .utils import DEFAULT_TIMEOUT, Callback, T
from .utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T


def _add_timeout(func):
Expand Down Expand Up @@ -213,15 +213,14 @@ async def unstage(self) -> None:
self._del_cache(self._get_cache().set_staged(False))


USE_DEFAULT_TIMEOUT = "USE_DEFAULT_TIMEOUT"


class SignalW(Signal[T], Movable):
"""Signal that can be set"""

def set(self, value: T, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
def set(
self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout
) -> AsyncStatus:
"""Set the value and return a status saying when it's done"""
if timeout is USE_DEFAULT_TIMEOUT:
if timeout is CalculateTimeout:
timeout = self._timeout

async def do_set():
Expand All @@ -248,9 +247,11 @@ async def locate(self) -> Location:
class SignalX(Signal):
"""Signal that puts the default value"""

def trigger(self, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
def trigger(
self, wait=True, timeout: CalculatableTimeout = CalculateTimeout
) -> AsyncStatus:
"""Trigger the action and return a status saying when it's done"""
if timeout is USE_DEFAULT_TIMEOUT:
if timeout is CalculateTimeout:
timeout = self._timeout
coro = self._backend.put(None, wait=wait, timeout=timeout)
return AsyncStatus(coro)
Expand All @@ -270,7 +271,7 @@ def soft_signal_r_and_setter(
datatype: Optional[Type[T]] = None,
initial_value: Optional[T] = None,
name: str = "",
) -> Tuple[SignalR[T], Callable[[T]]]:
) -> Tuple[SignalR[T], Callable[[T], None]]:
"""Returns a tuple of a read-only Signal and a callable through
which the signal can be internally modified within the device. Use
soft_signal_rw if you want a device that is externally modifiable
Expand Down Expand Up @@ -394,7 +395,7 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int):


async def observe_value(
signal: SignalR[T], timeout=None, done_status: Status | None = None
signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None
) -> AsyncGenerator[T, None]:
"""Subscribe to the value of a signal so it can be iterated from.

Expand All @@ -403,8 +404,12 @@ async def observe_value(
signal:
Call subscribe_value on this at the start, and clear_sub on it at the
end
timeout:
If given, how long to wait for each updated value in seconds. If an update
is not produced in this time then raise asyncio.TimeoutError
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.

Notes
-----
Expand All @@ -414,9 +419,7 @@ async def observe_value(
do_something_with(value)
"""

class StatusIsDone: ...

q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue()
q: asyncio.Queue[T | Status] = asyncio.Queue()
if timeout is None:
get_value = q.get
else:
Expand All @@ -425,16 +428,19 @@ async def get_value():
return await asyncio.wait_for(q.get(), timeout)

if done_status is not None:
done_status.add_callback(lambda _: q.put_nowait(StatusIsDone()))
done_status.add_callback(q.put_nowait)

signal.subscribe_value(q.put_nowait)
try:
while True:
item = await get_value()
if not isinstance(item, StatusIsDone):
yield item
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
else:
break
else:
break
yield item
finally:
signal.clear_sub(q.put_nowait)

Expand Down
11 changes: 11 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
ErrorText = Union[str, Dict[str, Exception]]


class CalculateTimeout:
"""Sentinel class used to implement ``myfunc(timeout=CalculateTimeout)``

This signifies that the function should calculate a suitable non-zero
timeout itself
"""


CalculatableTimeout = float | None | Type[CalculateTimeout]


class NotConnected(Exception):
"""Exception to be raised if a `Device.connect` is cancelled"""

Expand Down
56 changes: 30 additions & 26 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import string
import subprocess
import sys
from dataclasses import replace
from enum import Enum
from pathlib import Path

Expand All @@ -22,7 +21,13 @@
WatchableAsyncStatus,
observe_value,
)
from ophyd_async.core.utils import WatcherUpdate
from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
CalculateTimeout,
WatcherUpdate,
)

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -66,11 +71,9 @@ def __init__(self, prefix: str, name="") -> None:
# Define some signals
with self.add_children_as_readables(HintedSignal):
self.readback = epics_signal_r(float, prefix + "Readback")

with self.add_children_as_readables(ConfigSignal):
self.velocity = epics_signal_rw(float, prefix + "Velocity")
self.units = epics_signal_r(str, prefix + "Readback.EGU")

self.setpoint = epics_signal_rw(float, prefix + "Setpoint")
self.precision = epics_signal_r(int, prefix + "Readback.PREC")
# Signals that collide with standard methods should have a trailing underscore
Expand All @@ -85,41 +88,42 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float):
@WatchableAsyncStatus.wrap
async def set(
self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout
):
self._set_success = True
# time.monotonic won't go backwards in case of NTP corrections
old_position, units, precision = await asyncio.gather(
old_position, units, precision, velocity = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
self.velocity.get_value(),
)
if timeout is CalculateTimeout:
assert velocity > 0, "Mover has zero velocity"
timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT
# Make an Event that will be set on completion, and a Status that will
# error if not done in time
done = asyncio.Event()
done_status = AsyncStatus(asyncio.wait_for(done.wait(), timeout))
# Wait for the value to set, but don't wait for put completion callback
move_status = self.setpoint.set(new_position, wait=True)
if not self._set_success:
raise RuntimeError("Motor was stopped")
# return a template to set() which it can use to yield progress updates
return (
WatcherUpdate(
await self.setpoint.set(new_position, wait=False)
async for current_position in observe_value(
self.readback, done_status=done_status
):
yield WatcherUpdate(
current=current_position,
initial=old_position,
current=old_position,
target=new_position,
name=self.name,
unit=units,
precision=precision,
),
move_status,
)

@WatchableAsyncStatus.wrap # uses the timeout argument from the function it wraps
async def set(self, new_position: float, timeout: float | None = None):
update, _ = await self._move(new_position)
async for current_position in observe_value(self.readback):
yield replace(
update,
name=self.name,
current=current_position,
)
if np.isclose(current_position, new_position):
done.set()
break
if not self._set_success:
raise RuntimeError("Motor was stopped")

async def stop(self, success=True):
self._set_success = success
Expand Down
Loading
Loading