Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move timeout handling to functions wrapped with AsyncStatus #318

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Awaitable,
Callable,
Generic,
SupportsFloat,
Type,
TypeVar,
cast,
Expand All @@ -27,15 +26,11 @@
class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None):
if isinstance(timeout, SupportsFloat):
timeout = float(timeout)
def __init__(self, awaitable: Awaitable):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(
asyncio.wait_for(awaitable, timeout=timeout)
)
self.task = asyncio.create_task(awaitable)
self.task.add_done_callback(self._run_callbacks)
self._callbacks: list[Callback[Status]] = []

Expand All @@ -49,9 +44,8 @@ def add_callback(self, callback: Callback[Status]):
self._callbacks.append(callback)

def _run_callbacks(self, task: asyncio.Task):
if not task.cancelled():
for callback in self._callbacks:
callback(self)
for callback in self._callbacks:
callback(self)

def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
Expand Down Expand Up @@ -93,14 +87,11 @@ def __repr__(self) -> str:
class AsyncStatus(AsyncStatusBase):
@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
"""Wrap an async function in an AsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
Expand All @@ -110,15 +101,11 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
timeout: SupportsFloat | None = None,
):
def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator), timeout)
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
Expand Down Expand Up @@ -146,14 +133,10 @@ def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout' as an argument, this must be a float or None, and it
will be propagated to the status."""
"""Wrap an AsyncIterator in a WatchableAsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)
return cls(f(*args, **kwargs))

return cast(Callable[P, WAS], wrap_f)
23 changes: 14 additions & 9 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def soft_signal_r_and_setter(
datatype: Optional[Type[T]] = None,
initial_value: Optional[T] = None,
name: str = "",
) -> Tuple[SignalR[T], Callable[[T]]]:
) -> Tuple[SignalR[T], Callable[[T], None]]:
"""Returns a tuple of a read-only Signal and a callable through
which the signal can be internally modified within the device. Use
soft_signal_rw if you want a device that is externally modifiable
Expand Down Expand Up @@ -394,7 +394,7 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int):


async def observe_value(
signal: SignalR[T], timeout=None, done_status: Status | None = None
signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None
) -> AsyncGenerator[T, None]:
"""Subscribe to the value of a signal so it can be iterated from.

Expand All @@ -403,8 +403,12 @@ async def observe_value(
signal:
Call subscribe_value on this at the start, and clear_sub on it at the
end
timeout:
If given, how long to wait for each updated value in seconds. If an update
is not produced in this time then raise asyncio.TimeoutError
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.

Notes
-----
Expand All @@ -414,9 +418,7 @@ async def observe_value(
do_something_with(value)
"""

class StatusIsDone: ...

q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue()
q: asyncio.Queue[T | Status] = asyncio.Queue()
if timeout is None:
get_value = q.get
else:
Expand All @@ -425,16 +427,19 @@ async def get_value():
return await asyncio.wait_for(q.get(), timeout)

if done_status is not None:
done_status.add_callback(lambda _: q.put_nowait(StatusIsDone()))
done_status.add_callback(q.put_nowait)

signal.subscribe_value(q.put_nowait)
try:
while True:
item = await get_value()
if not isinstance(item, StatusIsDone):
yield item
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
else:
break
else:
break
yield item
finally:
signal.clear_sub(q.put_nowait)

Expand Down
50 changes: 24 additions & 26 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import string
import subprocess
import sys
from dataclasses import replace
from enum import Enum
from pathlib import Path

Expand All @@ -22,7 +21,8 @@
WatchableAsyncStatus,
observe_value,
)
from ophyd_async.core.utils import WatcherUpdate
from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.utils import DEFAULT_TIMEOUT, WatcherUpdate

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -66,11 +66,9 @@ def __init__(self, prefix: str, name="") -> None:
# Define some signals
with self.add_children_as_readables(HintedSignal):
self.readback = epics_signal_r(float, prefix + "Readback")

with self.add_children_as_readables(ConfigSignal):
self.velocity = epics_signal_rw(float, prefix + "Velocity")
self.units = epics_signal_r(str, prefix + "Readback.EGU")

self.setpoint = epics_signal_rw(float, prefix + "Setpoint")
self.precision = epics_signal_r(int, prefix + "Readback.PREC")
# Signals that collide with standard methods should have a trailing underscore
Expand All @@ -85,41 +83,41 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float):
@WatchableAsyncStatus.wrap
async def set(self, new_position: float):
self._set_success = True
# time.monotonic won't go backwards in case of NTP corrections
old_position, units, precision = await asyncio.gather(
old_position, units, precision, velocity = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
self.velocity.get_value(),
)
assert velocity > 0, "Mover has zero velocity"
move_time = abs(new_position - old_position) / velocity
# Make an Event that will be set on completion, and a Status that will
# error if not done in time
done = asyncio.Event()
done_status = AsyncStatus(
asyncio.wait_for(done.wait(), move_time + DEFAULT_TIMEOUT)
)
# Wait for the value to set, but don't wait for put completion callback
move_status = self.setpoint.set(new_position, wait=True)
if not self._set_success:
raise RuntimeError("Motor was stopped")
# return a template to set() which it can use to yield progress updates
return (
WatcherUpdate(
await self.setpoint.set(new_position, wait=False)
async for current_position in observe_value(
self.readback, done_status=done_status
):
yield WatcherUpdate(
current=current_position,
initial=old_position,
current=old_position,
target=new_position,
name=self.name,
unit=units,
precision=precision,
),
move_status,
)

@WatchableAsyncStatus.wrap # uses the timeout argument from the function it wraps
async def set(self, new_position: float, timeout: float | None = None):
update, _ = await self._move(new_position)
async for current_position in observe_value(self.readback):
yield replace(
update,
name=self.name,
current=current_position,
)
if np.isclose(current_position, new_position):
done.set()
break
if not self._set_success:
raise RuntimeError("Motor was stopped")

async def stop(self, success=True):
self._set_success = success
Expand Down
52 changes: 24 additions & 28 deletions src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import asyncio
from dataclasses import replace

from bluesky.protocols import Movable, Stoppable

from ophyd_async.core import (
AsyncStatus,
ConfigSignal,
HintedSignal,
StandardReadable,
WatchableAsyncStatus,
)
from ophyd_async.core.signal import observe_value
from ophyd_async.core.utils import WatcherUpdate
from ophyd_async.core.utils import DEFAULT_TIMEOUT, WatcherUpdate

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -47,42 +45,40 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.user_readback.set_name(name)

async def _move(
self, new_position: float
) -> tuple[WatcherUpdate[float], AsyncStatus]:
@WatchableAsyncStatus.wrap
async def set(self, new_position: float):
self._set_success = True
old_position, units, precision = await asyncio.gather(
(
old_position,
units,
precision,
velocity,
acceleration_time,
) = await asyncio.gather(
self.user_setpoint.get_value(),
self.motor_egu.get_value(),
self.precision.get_value(),
self.velocity.get_value(),
self.acceleration_time.get_value(),
)
move_status = self.user_setpoint.set(new_position, wait=True)
if not self._set_success:
raise RuntimeError("Motor was stopped")
return (
WatcherUpdate(
initial=old_position,
current=old_position,
target=new_position,
unit=units,
precision=precision,
),
move_status,
assert velocity > 0, "Motor has zero velocity"
move_time = abs(new_position - old_position) / velocity + 2 * acceleration_time
move_status = self.user_setpoint.set(
new_position, wait=True, timeout=move_time + DEFAULT_TIMEOUT
)

@WatchableAsyncStatus.wrap
async def set(self, new_position: float, timeout: float | None = None):
update, move_status = await self._move(new_position)
async for current_position in observe_value(
self.user_readback, done_status=move_status
):
if not self._set_success:
raise RuntimeError("Motor was stopped")
yield replace(
update,
name=self.name,
yield WatcherUpdate(
current=current_position,
initial=old_position,
target=new_position,
name=self.name,
unit=units,
precision=precision,
)
if not self._set_success:
raise RuntimeError("Motor was stopped")

async def stop(self, success=False):
self._set_success = success
Expand Down
Loading