diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 9577d7c6a2..f49e127df2 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -5,7 +5,7 @@ ShapeProvider, StaticDirectoryProvider, ) -from .async_status import AsyncStatus +from .async_status import AsyncStatus, WatchableAsyncStatus from .detector import ( DetectorControl, DetectorTrigger, @@ -96,6 +96,7 @@ "set_mock_value", "wait_for_value", "AsyncStatus", + "WatchableAsyncStatus", "DirectoryInfo", "DirectoryProvider", "NameProvider", diff --git a/src/ophyd_async/core/async_status.py b/src/ophyd_async/core/async_status.py index 2cdd3e804c..19e3f1d201 100644 --- a/src/ophyd_async/core/async_status.py +++ b/src/ophyd_async/core/async_status.py @@ -2,30 +2,42 @@ import asyncio import functools -from typing import Awaitable, Callable, Coroutine, List, Optional, cast +import time +from dataclasses import asdict, replace +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Generic, + SupportsFloat, + Type, + TypeVar, + cast, +) from bluesky.protocols import Status -from .utils import Callback, T +from ..protocols import Watcher +from .utils import Callback, P, T, WatcherUpdate +AS = TypeVar("AS", bound="AsyncStatus") +WAS = TypeVar("WAS", bound="WatchableAsyncStatus") -class AsyncStatus(Status): + +class AsyncStatusBase(Status): """Convert asyncio awaitable to bluesky Status interface""" - def __init__( - self, - awaitable: Awaitable, - watchers: Optional[List[Callable]] = None, - ): + def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None): + if isinstance(timeout, SupportsFloat): + timeout = float(timeout) if isinstance(awaitable, asyncio.Task): self.task = awaitable else: - self.task = asyncio.create_task(awaitable) # type: ignore - + self.task = asyncio.create_task( + asyncio.wait_for(awaitable, timeout=timeout) + ) self.task.add_done_callback(self._run_callbacks) - - self._callbacks = cast(List[Callback[Status]], []) - self._watchers = watchers + self._callbacks: list[Callback[Status]] = [] def __await__(self): return self.task.__await__() @@ -41,15 +53,11 @@ def _run_callbacks(self, task: asyncio.Task): for callback in self._callbacks: callback(self) - # TODO: remove ignore and bump min version when bluesky v1.12.0 is released - def exception( # type: ignore - self, timeout: Optional[float] = 0.0 - ) -> Optional[BaseException]: + def exception(self, timeout: float | None = 0.0) -> BaseException | None: if timeout != 0.0: - raise Exception( + raise ValueError( "cannot honour any timeout other than 0 in an asynchronous function" ) - if self.task.done(): try: return self.task.exception() @@ -69,22 +77,6 @@ def success(self) -> bool: and self.task.exception() is None ) - def watch(self, watcher: Callable): - """Add watcher to the list of interested parties. - - Arguments as per Bluesky :external+bluesky:meth:`watch` protocol. - """ - if self._watchers is not None: - self._watchers.append(watcher) - - @classmethod - def wrap(cls, f: Callable[[T], Coroutine]) -> Callable[[T], "AsyncStatus"]: - @functools.wraps(f) - def wrap_f(self) -> AsyncStatus: - return AsyncStatus(f(self)) - - return wrap_f - def __repr__(self) -> str: if self.done: if e := self.exception(): @@ -96,3 +88,72 @@ def __repr__(self) -> str: return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>" __str__ = __repr__ + + +class AsyncStatus(AsyncStatusBase): + @classmethod + def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]: + @functools.wraps(f) + def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: + # We can't type this more properly because Concatenate/ParamSpec doesn't + # yet support keywords + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters + timeout = kwargs.get("timeout") + assert isinstance(timeout, SupportsFloat) or timeout is None + return cls(f(*args, **kwargs), timeout=timeout) + + # type is actually functools._Wrapped[P, Awaitable, P, AS] + # but functools._Wrapped is not necessarily available + return cast(Callable[P, AS], wrap_f) + + +class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): + """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.""" + + def __init__( + self, + iterator: AsyncIterator[WatcherUpdate[T]], + timeout: SupportsFloat | None = None, + ): + self._watchers: list[Watcher] = [] + self._start = time.monotonic() + self._last_update: WatcherUpdate[T] | None = None + super().__init__(self._notify_watchers_from(iterator), timeout) + + async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): + async for update in iterator: + self._last_update = ( + update + if update.time_elapsed is not None + else replace(update, time_elapsed=time.monotonic() - self._start) + ) + for watcher in self._watchers: + self._update_watcher(watcher, self._last_update) + + def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]): + vals = asdict( + update, dict_factory=lambda d: {k: v for k, v in d if v is not None} + ) + watcher(**vals) + + def watch(self, watcher: Watcher): + self._watchers.append(watcher) + if self._last_update: + self._update_watcher(watcher, self._last_update) + + @classmethod + def wrap( + cls: Type[WAS], + f: Callable[P, AsyncIterator[WatcherUpdate[T]]], + ) -> Callable[P, WAS]: + """Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes + 'timeout' as an argument, this must be a float or None, and it + will be propagated to the status.""" + + @functools.wraps(f) + def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: + timeout = kwargs.get("timeout") + assert isinstance(timeout, SupportsFloat) or timeout is None + return cls(f(*args, **kwargs), timeout=timeout) + + return cast(Callable[P, WAS], wrap_f) diff --git a/src/ophyd_async/core/detector.py b/src/ophyd_async/core/detector.py index 99b3c8f95e..ee168f8830 100644 --- a/src/ophyd_async/core/detector.py +++ b/src/ophyd_async/core/detector.py @@ -31,9 +31,9 @@ from ophyd_async.protocols import AsyncConfigurable, AsyncReadable -from .async_status import AsyncStatus +from .async_status import AsyncStatus, WatchableAsyncStatus from .device import Device -from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts +from .utils import DEFAULT_TIMEOUT, WatcherUpdate, merge_gathered_dicts T = TypeVar("T") @@ -188,7 +188,7 @@ def __init__( self._trigger_info: Optional[TriggerInfo] = None # For kickoff self._watchers: List[Callable] = [] - self._fly_status: Optional[AsyncStatus] = None + self._fly_status: Optional[WatchableAsyncStatus] = None self._fly_start: float self._intial_frame: int @@ -292,43 +292,37 @@ async def _prepare(self, value: T) -> None: f"Detector {self.controller} needs at least {required}s deadtime, " f"but trigger logic provides only {self._trigger_info.deadtime}s" ) - self._arm_status = await self.controller.arm( num=self._trigger_info.num, trigger=self._trigger_info.trigger, exposure=self._trigger_info.livetime, ) - - @AsyncStatus.wrap - async def kickoff(self) -> None: - self._fly_status = AsyncStatus(self._fly(), self._watchers) self._fly_start = time.monotonic() - async def _fly(self) -> None: - await self._observe_writer_indicies(self._last_frame) - - async def _observe_writer_indicies(self, end_observation: int): + @AsyncStatus.wrap + async def kickoff(self): + if not self._arm_status: + raise Exception("Detector not armed!") + + @WatchableAsyncStatus.wrap + async def complete(self): + assert self._arm_status, "Prepare not run" + assert self._trigger_info async for index in self.writer.observe_indices_written( self._frame_writing_timeout ): - for watcher in self._watchers: - watcher( - name=self.name, - current=index, - initial=self._initial_frame, - target=end_observation, - unit="", - precision=0, - time_elapsed=time.monotonic() - self._fly_start, - ) - if index >= end_observation: + yield WatcherUpdate( + name=self.name, + current=index, + initial=self._initial_frame, + target=self._trigger_info.num, + unit="", + precision=0, + time_elapsed=time.monotonic() - self._fly_start, + ) + if index >= self._trigger_info.num: break - @AsyncStatus.wrap - async def complete(self) -> AsyncStatus: - assert self._fly_status, "Kickoff not run" - return await self._fly_status - async def describe_collect(self) -> Dict[str, DataKey]: return self._describe diff --git a/src/ophyd_async/core/signal.py b/src/ophyd_async/core/signal.py index 2632b9f916..59b849b714 100644 --- a/src/ophyd_async/core/signal.py +++ b/src/ophyd_async/core/signal.py @@ -21,6 +21,7 @@ Location, Movable, Reading, + Status, Subscribable, ) @@ -390,7 +391,9 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int): ) -async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, None]: +async def observe_value( + signal: SignalR[T], timeout=None, done_status: Status | None = None +) -> AsyncGenerator[T, None]: """Subscribe to the value of a signal so it can be iterated from. Parameters @@ -398,6 +401,8 @@ async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, N signal: Call subscribe_value on this at the start, and clear_sub on it at the end + done_status: + If this status is complete, stop observing and make the iterator return. Notes ----- @@ -406,7 +411,10 @@ async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, N async for value in observe_value(sig): do_something_with(value) """ - q: asyncio.Queue[T] = asyncio.Queue() + + class StatusIsDone: ... + + q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -414,10 +422,17 @@ async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, N async def get_value(): return await asyncio.wait_for(q.get(), timeout) + if done_status is not None: + done_status.add_callback(lambda _: q.put_nowait(StatusIsDone())) + signal.subscribe_value(q.put_nowait) try: while True: - yield await get_value() + item = await get_value() + if not isinstance(item, StatusIsDone): + yield item + else: + break finally: signal.clear_sub(q.put_nowait) diff --git a/src/ophyd_async/core/utils.py b/src/ophyd_async/core/utils.py index ad70bcb62e..b09b9322a1 100644 --- a/src/ophyd_async/core/utils.py +++ b/src/ophyd_async/core/utils.py @@ -2,13 +2,16 @@ import asyncio import logging +from dataclasses import dataclass from typing import ( Awaitable, Callable, Dict, + Generic, Iterable, List, Optional, + ParamSpec, Type, TypeVar, Union, @@ -18,6 +21,7 @@ from bluesky.protocols import Reading T = TypeVar("T") +P = ParamSpec("P") Callback = Callable[[T], None] #: A function that will be called with the Reading and value when the @@ -77,6 +81,21 @@ def __str__(self) -> str: return self.format_error_string(indent="") +@dataclass(frozen=True) +class WatcherUpdate(Generic[T]): + """A dataclass such that, when expanded, it provides the kwargs for a watcher""" + + current: T + initial: T + target: T + name: str | None = None + unit: str | None = None + precision: float | None = None + fraction: float | None = None + time_elapsed: float | None = None + time_remaining: float | None = None + + async def wait_for_connection(**coros: Awaitable[None]): """Call many underlying signals, accumulating exceptions and returning them diff --git a/src/ophyd_async/epics/demo/__init__.py b/src/ophyd_async/epics/demo/__init__.py index 83ddba7dec..4e183d8548 100644 --- a/src/ophyd_async/epics/demo/__init__.py +++ b/src/ophyd_async/epics/demo/__init__.py @@ -6,23 +6,23 @@ import string import subprocess import sys -import time +from dataclasses import replace from enum import Enum from pathlib import Path -from typing import Callable, List, Optional import numpy as np from bluesky.protocols import Movable, Stoppable from ophyd_async.core import ( - AsyncStatus, ConfigSignal, Device, DeviceVector, HintedSignal, StandardReadable, + WatchableAsyncStatus, observe_value, ) +from ophyd_async.core.utils import WatcherUpdate from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -85,46 +85,41 @@ def set_name(self, name: str): # Readback should be named the same as its parent in read() self.readback.set_name(name) - async def _move(self, new_position: float, watchers: List[Callable] = []): + async def _move(self, new_position: float): self._set_success = True # time.monotonic won't go backwards in case of NTP corrections - start = time.monotonic() old_position, units, precision = await asyncio.gather( self.setpoint.get_value(), self.units.get_value(), self.precision.get_value(), ) # Wait for the value to set, but don't wait for put completion callback - await self.setpoint.set(new_position, wait=False) - async for current_position in observe_value(self.readback): - for watcher in watchers: - watcher( - name=self.name, - current=current_position, - initial=old_position, - target=new_position, - unit=units, - precision=precision, - time_elapsed=time.monotonic() - start, - ) - if np.isclose(current_position, new_position): - break + move_status = self.setpoint.set(new_position, wait=True) if not self._set_success: raise RuntimeError("Motor was stopped") + # return a template to set() which it can use to yield progress updates + return ( + WatcherUpdate( + initial=old_position, + current=old_position, + target=new_position, + unit=units, + precision=precision, + ), + move_status, + ) - def move(self, new_position: float, timeout: Optional[float] = None): - """Commandline only synchronous move of a Motor""" - from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop - - if in_bluesky_event_loop(): - raise RuntimeError("Will deadlock run engine if run in a plan") - call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore - - # TODO: this fails if we call from the cli, but works if we "ipython await" it - def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus: - watchers: List[Callable] = [] - coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout) - return AsyncStatus(coro, watchers) + @WatchableAsyncStatus.wrap # uses the timeout argument from the function it wraps + async def set(self, new_position: float, timeout: float | None = None): + update, _ = await self._move(new_position) + async for current_position in observe_value(self.readback): + yield replace( + update, + name=self.name, + current=current_position, + ) + if np.isclose(current_position, new_position): + break async def stop(self, success=True): self._set_success = success diff --git a/src/ophyd_async/epics/motion/motor.py b/src/ophyd_async/epics/motion/motor.py index 33d3a41062..4310a58628 100644 --- a/src/ophyd_async/epics/motion/motor.py +++ b/src/ophyd_async/epics/motion/motor.py @@ -1,10 +1,17 @@ import asyncio -import time -from typing import Callable, List, Optional +from dataclasses import replace from bluesky.protocols import Movable, Stoppable -from ophyd_async.core import AsyncStatus, ConfigSignal, HintedSignal, StandardReadable +from ophyd_async.core import ( + AsyncStatus, + ConfigSignal, + HintedSignal, + StandardReadable, + WatchableAsyncStatus, +) +from ophyd_async.core.signal import observe_value +from ophyd_async.core.utils import WatcherUpdate from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -41,54 +48,46 @@ def set_name(self, name: str): self.user_readback.set_name(name) async def _move( - self, new_position: float, watchers: Optional[List[Callable]] = None - ): - if watchers is None: - watchers = [] + self, new_position: float + ) -> tuple[WatcherUpdate[float], AsyncStatus]: self._set_success = True - start = time.monotonic() old_position, units, precision = await asyncio.gather( self.user_setpoint.get_value(), self.motor_egu.get_value(), self.precision.get_value(), ) - - def update_watchers(current_position: float): - for watcher in watchers: - watcher( - name=self.name, - current=current_position, - initial=old_position, - target=new_position, - unit=units, - precision=precision, - time_elapsed=time.monotonic() - start, - ) - - self.user_readback.subscribe_value(update_watchers) - try: - await self.user_setpoint.set(new_position) - finally: - self.user_readback.clear_sub(update_watchers) + move_status = self.user_setpoint.set(new_position, wait=True) if not self._set_success: raise RuntimeError("Motor was stopped") + return ( + WatcherUpdate( + initial=old_position, + current=old_position, + target=new_position, + unit=units, + precision=precision, + ), + move_status, + ) - def move(self, new_position: float, timeout: Optional[float] = None): - """Commandline only synchronous move of a Motor""" - from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop - - if in_bluesky_event_loop(): - raise RuntimeError("Will deadlock run engine if run in a plan") - call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore - - def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus: - watchers: List[Callable] = [] - coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout) - return AsyncStatus(coro, watchers) + @WatchableAsyncStatus.wrap + async def set(self, new_position: float, timeout: float | None = None): + update, move_status = await self._move(new_position) + async for current_position in observe_value( + self.user_readback, done_status=move_status + ): + if not self._set_success: + raise RuntimeError("Motor was stopped") + yield replace( + update, + name=self.name, + current=current_position, + ) async def stop(self, success=False): self._set_success = success # Put with completion will never complete as we are waiting for completion on # the move above, so need to pass wait=False - status = self.motor_stop.trigger(wait=False) - await status + await self.motor_stop.trigger(wait=False) + # Trigger any callbacks + await self.user_readback._backend.put(await self.user_readback.get_value()) diff --git a/src/ophyd_async/protocols.py b/src/ophyd_async/protocols.py index f2d7ec4ab6..05256e46c8 100644 --- a/src/ophyd_async/protocols.py +++ b/src/ophyd_async/protocols.py @@ -1,9 +1,20 @@ +from __future__ import annotations + from abc import abstractmethod -from typing import Dict, Protocol, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Protocol, + TypeVar, + runtime_checkable, +) from bluesky.protocols import DataKey, HasName, Reading -from ophyd_async.core.async_status import AsyncStatus +if TYPE_CHECKING: + from ophyd_async.core.async_status import AsyncStatus @runtime_checkable @@ -94,3 +105,22 @@ def unstage(self) -> AsyncStatus: unstaging. """ ... + + +C = TypeVar("C", contravariant=True) + + +class Watcher(Protocol, Generic[C]): + @staticmethod + def __call__( + *, + current: C, + initial: C, + target: C, + name: str | None, + unit: str | None, + precision: float | None, + fraction: float | None, + time_elapsed: float | None, + time_remaining: float | None, + ) -> Any: ... diff --git a/src/ophyd_async/sim/demo/sim_motor.py b/src/ophyd_async/sim/demo/sim_motor.py index ce41354ac7..aeb4f38844 100644 --- a/src/ophyd_async/sim/demo/sim_motor.py +++ b/src/ophyd_async/sim/demo/sim_motor.py @@ -1,13 +1,18 @@ import asyncio import time -from typing import Callable, List, Optional +from dataclasses import replace from bluesky.protocols import Movable, Stoppable from ophyd_async.core import StandardReadable -from ophyd_async.core.async_status import AsyncStatus -from ophyd_async.core.signal import soft_signal_r_and_setter, soft_signal_rw +from ophyd_async.core.async_status import AsyncStatus, WatchableAsyncStatus +from ophyd_async.core.signal import ( + observe_value, + soft_signal_r_and_setter, + soft_signal_rw, +) from ophyd_async.core.standard_readable import ConfigSignal, HintedSignal +from ophyd_async.core.utils import WatcherUpdate class SimMotor(StandardReadable, Movable, Stoppable): @@ -27,10 +32,10 @@ def __init__(self, name="", instant=True) -> None: with self.add_children_as_readables(ConfigSignal): self.velocity = soft_signal_rw(float, 1.0) - self.egu = soft_signal_rw(float, "mm") + self.egu = soft_signal_rw(str, "mm") self._instant = instant - self._move_task: Optional[asyncio.Task] = None + self._move_status: AsyncStatus | None = None # Define some signals self.user_setpoint = soft_signal_rw(float, 0) @@ -44,21 +49,37 @@ def stop(self, success=False): """ Stop the motor if it is moving """ - if self._move_task: - self._move_task.cancel() - self._move_task = None + if self._move_status: + self._move_status.task.cancel() + self._move_status = None + + async def trigger_callbacks(): + await self.user_readback._backend.put( + await self.user_readback._backend.get_value() + ) + + asyncio.create_task(trigger_callbacks()) self._set_success = success - def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus: # noqa: F821 + @WatchableAsyncStatus.wrap + async def set(self, new_position: float, timeout: float | None = None): """ Asynchronously move the motor to a new position. """ - watchers: List[Callable] = [] - coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout) - return AsyncStatus(coro, watchers) + update, move_status = await self._move(new_position, timeout) + async for current_position in observe_value( + self.user_readback, done_status=move_status + ): + if not self._set_success: + raise RuntimeError("Motor was stopped") + yield replace( + update, + name=self.name, + current=current_position, + ) - async def _move(self, new_position: float, watchers: List[Callable] = []): + async def _move(self, new_position: float, timeout: float | None = None): """ Start the motor moving to a new position. @@ -67,6 +88,7 @@ async def _move(self, new_position: float, watchers: List[Callable] = []): """ self.stop() start = time.monotonic() + self._set_success = True current_position = await self.user_readback.get_value() distance = abs(new_position - current_position) @@ -94,25 +116,18 @@ async def update_position(): self._user_readback_set(current_position) - # notify watchers of the new position - for watcher in watchers: - watcher( - name=self.name, - current=current_position, - initial=old_position, - target=new_position, - unit=units, - time_elapsed=time.monotonic() - start, - ) - # 10hz update loop await asyncio.sleep(0.1) - # set up a task that updates the motor position at 10hz - self._move_task = asyncio.create_task(update_position()) - - try: - await self._move_task - finally: - if not self._set_success: - raise RuntimeError("Motor was stopped") + # set up a task that updates the motor position at ~10hz + self._move_status = AsyncStatus(asyncio.wait_for(update_position(), timeout)) + + return ( + WatcherUpdate( + initial=old_position, + current=old_position, + target=new_position, + unit=units, + ), + self._move_status, + ) diff --git a/tests/core/test_async_status.py b/tests/core/test_async_status.py index 78e09abb4a..842df32136 100644 --- a/tests/core/test_async_status.py +++ b/tests/core/test_async_status.py @@ -50,13 +50,17 @@ async def test_async_status_has_no_exception_if_coroutine_successful(normal_coro async def test_async_status_success_if_cancelled(normal_coroutine): - status = AsyncStatus(normal_coroutine()) + coro = normal_coroutine() + status = AsyncStatus(coro) assert status.exception() is None status.task.cancel() with pytest.raises(asyncio.CancelledError): await status assert status.success is False assert isinstance(status.exception(), asyncio.CancelledError) + # asyncio will RuntimeWarning us about this never being awaited if we don't. + # RunEngine handled this as a special case + await coro async def coroutine_to_wrap(time: float): @@ -122,7 +126,7 @@ def set(self, value) -> AsyncStatus: async def test_status_propogates_traceback_under_RE(RE) -> None: - expected_call_stack = ["_set", "_fail"] + expected_call_stack = ["wait_for", "_set", "_fail"] d = FailingMovable() with pytest.raises(FailedStatus) as ctx: RE(bps.mv(d, 3)) @@ -145,3 +149,41 @@ async def test_async_status_exception_timeout(): st = AsyncStatus(asyncio.sleep(0.1)) with pytest.raises(Exception): st.exception(timeout=1.0) + + +@pytest.fixture +def loop(): + return asyncio.get_event_loop() + + +def test_asyncstatus_wraps_bare_func(loop): + async def do_test(): + @AsyncStatus.wrap + async def coro_status(): + await asyncio.sleep(0.01) + + st = coro_status() + assert isinstance(st, AsyncStatus) + await asyncio.wait_for(st.task, None) + assert st.done + + loop.run_until_complete(do_test()) + + +def test_asyncstatus_wraps_bare_func_with_args_kwargs(loop): + async def do_test(): + test_result = 5 + + @AsyncStatus.wrap + async def coro_status(x: int, y: int, *, z=False): + await asyncio.sleep(0.01) + nonlocal test_result + test_result = x * y if z else 0 + + st = coro_status(3, 4, z=True) + assert isinstance(st, AsyncStatus) + await asyncio.wait_for(st.task, None) + assert st.done + assert test_result == 12 + + loop.run_until_complete(do_test()) diff --git a/tests/core/test_device_collector.py b/tests/core/test_device_collector.py index cc187a6be8..35c71508a9 100644 --- a/tests/core/test_device_collector.py +++ b/tests/core/test_device_collector.py @@ -2,7 +2,8 @@ import pytest from bluesky import plan_stubs as bps -from bluesky.run_engine import RunEngine, TransitionError +from bluesky.run_engine import RunEngine +from super_state_machine.errors import TransitionError from ophyd_async.core import DEFAULT_TIMEOUT, Device, DeviceCollector, NotConnected from ophyd_async.epics.motion import motor diff --git a/tests/core/test_watchable_async_status.py b/tests/core/test_watchable_async_status.py new file mode 100644 index 0000000000..721f191467 --- /dev/null +++ b/tests/core/test_watchable_async_status.py @@ -0,0 +1,200 @@ +import asyncio +from functools import partial +from typing import AsyncIterator + +import bluesky.plan_stubs as bps +import pytest +from bluesky.protocols import Movable + +from ophyd_async.core.async_status import AsyncStatus, WatchableAsyncStatus +from ophyd_async.core.signal import soft_signal_r_and_setter +from ophyd_async.core.standard_readable import StandardReadable +from ophyd_async.core.utils import WatcherUpdate + + +class SetFailed(Exception): + pass + + +def watcher_test( + storage: list[WatcherUpdate], + *, + name: str | None, + current: int | None, + initial: int | None, + target: int | None, + unit: str | None, + precision: float | None, + fraction: float | None, + time_elapsed: float | None, + time_remaining: float | None, +): + storage.append( + WatcherUpdate( + name=name, + current=current, + initial=initial, + target=target, + unit=unit, + precision=precision, + fraction=fraction, + time_elapsed=time_elapsed, + time_remaining=time_remaining, + ) + ) + + +class TWatcher: + updates: list[int] = [] + + def __call__( + self, + *, + name: str | None, + current: int | None, + initial: int | None, + target: int | None, + unit: str | None, + precision: float | None, + fraction: float | None, + time_elapsed: float | None, + time_remaining: float | None, + ) -> None: + self.updates.append(current or -1) + + +class ASTestDevice(StandardReadable, Movable): + def __init__(self, name: str = "") -> None: + self._staged: bool = False + self.sig, self._sig_setter = soft_signal_r_and_setter(datatype=int) + super().__init__(name) + + @AsyncStatus.wrap + async def stage(self): + self._staged = True + await asyncio.sleep(0.01) + + +class ASTestDeviceSingleSet(ASTestDevice): + @AsyncStatus.wrap + async def set(self, val): + assert self._staged + await asyncio.sleep(0.01) + self._sig_setter(val) # type: ignore + + +class ASTestDeviceTimeoutSet(ASTestDevice): + @WatchableAsyncStatus.wrap + async def set(self, val, timeout=0.01): + assert self._staged + await asyncio.sleep(0.01) + self._sig_setter(val - 1) # type: ignore + await asyncio.sleep(0.01) + yield WatcherUpdate(1, 1, 1) + await asyncio.sleep(0.01) + yield WatcherUpdate(1, 1, 1) + + +class ASTestDeviceIteratorSet(ASTestDevice): + def __init__( + self, name: str = "", values=[1, 2, 3, 4, 5], complete_set: bool = True + ) -> None: + self.values = values + self.complete_set = complete_set + super().__init__(name) + + @WatchableAsyncStatus.wrap + async def set(self, val) -> AsyncIterator: + assert self._staged + self._initial = await self.sig.get_value() + for point in self.values: + await asyncio.sleep(0.01) + yield WatcherUpdate( + name=self.name, + current=point, + initial=self._initial, + target=val, + unit="dimensionless", + precision=0.0, + time_elapsed=0, + time_remaining=0, + fraction=0, + ) + if self.complete_set: + self._sig_setter(val) # type: ignore + yield WatcherUpdate( + name=self.name, + current=val, + initial=self._initial, + target=val, + unit="dimensionless", + precision=0.0, + time_elapsed=0, + time_remaining=0, + fraction=0, + ) + else: + raise SetFailed + return + + +async def test_asyncstatus_wraps_both_stage_and_set(RE): + td = ASTestDeviceSingleSet() + await td.connect() + with pytest.raises(AssertionError): + st = td.set(5) + assert isinstance(st, AsyncStatus) + await st + await td.stage() + st = td.set(5) + assert isinstance(st, AsyncStatus) + await st + assert (await td.sig.get_value()) == 5 + RE(bps.abs_set(td, 3, wait=True)) + assert (await td.sig.get_value()) == 3 + + +async def test_asyncstatus_wraps_set_iterator_with_class_or_func_watcher(RE): + td = ASTestDeviceIteratorSet() + await td.connect() + await td.stage() + st = td.set(6) + updates = [] + + w = TWatcher() + st.watch(partial(watcher_test, updates)) + st.watch(w) + await st + assert st.done + assert st.success + assert len(updates) == 6 + assert sum(w.updates) == 21 + + +async def test_watchableasyncstatus_wraps_failing_set_iterator(RE): + td = ASTestDeviceIteratorSet(values=[1, 2, 3], complete_set=False) + await td.connect() + await td.stage() + st = td.set(6) + updates = [] + + st.watch(partial(watcher_test, updates)) + try: + await st + except Exception: + ... + assert st.done + assert not st.success + assert isinstance(st.exception(), SetFailed) + assert len(updates) == 3 + + +async def test_watchableasyncstatus_times_out(RE): + td = ASTestDeviceTimeoutSet() + await td.connect() + await td.stage() + st = td.set(6, timeout=0.01) + while not st.done: + await asyncio.sleep(0.01) + assert not st.success + assert isinstance(st.exception(), asyncio.TimeoutError) diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index fa76d34892..805e41d0ad 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -68,13 +68,36 @@ async def test_mover_stopped(mock_mover: demo.Mover): assert callbacks == [None] -class Watcher: +class DemoWatcher: def __init__(self) -> None: self._event = asyncio.Event() self._mock = Mock() - def __call__(self, *args, **kwargs): - self._mock(*args, **kwargs) + def __call__( + self, + *args, + current: float, + initial: float, + target: float, + name: str | None = None, + unit: str | None = None, + precision: float | None = None, + fraction: float | None = None, + time_elapsed: float | None = None, + time_remaining: float | None = None, + **kwargs, + ): + self._mock( + *args, + current=current, + initial=initial, + target=target, + name=name, + unit=unit, + precision=precision, + time_elapsed=time_elapsed, + **kwargs, + ) self._event.set() async def wait_for_call(self, *args, **kwargs): @@ -87,7 +110,7 @@ async def wait_for_call(self, *args, **kwargs): async def test_mover_moving_well(mock_mover: demo.Mover) -> None: s = mock_mover.set(0.55) - watcher = Watcher() + watcher = DemoWatcher() s.watch(watcher) done = Mock() s.add_callback(done) @@ -98,7 +121,7 @@ async def test_mover_moving_well(mock_mover: demo.Mover) -> None: target=0.55, unit="mm", precision=3, - time_elapsed=pytest.approx(0.0, abs=0.05), + time_elapsed=ANY, # Test is flaky in slow CI ) await assert_value(mock_mover.setpoint, 0.55) @@ -113,7 +136,7 @@ async def test_mover_moving_well(mock_mover: demo.Mover) -> None: target=0.55, unit="mm", precision=3, - time_elapsed=pytest.approx(0.1, abs=0.05), + time_elapsed=ANY, # Test is flaky in slow CI ) set_mock_value(mock_mover.readback, 0.5499999) await asyncio.sleep(A_WHILE) @@ -248,18 +271,6 @@ async def test_assembly_renaming() -> None: assert thing.x.stop_.name == "foo-x-stop" -def test_mover_in_re(mock_mover: demo.Mover, RE) -> None: - mock_mover.move(0) - - def my_plan(): - mock_mover.move(0) - return - yield - - with pytest.raises(RuntimeError, match="Will deadlock run engine if run in a plan"): - RE(my_plan()) - - async def test_dynamic_sensor_group_disconnected(): with pytest.raises(NotConnected): async with DeviceCollector(timeout=0.1): diff --git a/tests/epics/motion/test_motor.py b/tests/epics/motion/test_motor.py index 0654661bb9..63343d05e8 100644 --- a/tests/epics/motion/test_motor.py +++ b/tests/epics/motion/test_motor.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Dict from unittest.mock import Mock, call @@ -28,7 +29,56 @@ async def sim_motor(): yield sim_motor +async def wait_for_eq(item, attribute, comparison, timeout): + timeout_time = time.monotonic() + timeout + while getattr(item, attribute) != comparison: + await asyncio.sleep(A_BIT) + if time.monotonic() > timeout_time: + raise TimeoutError + + async def test_motor_moving_well(sim_motor: motor.Motor) -> None: + set_mock_put_proceeds(sim_motor.user_setpoint, False) + s = sim_motor.set(0.55) + watcher = Mock() + s.watch(watcher) + done = Mock() + s.add_callback(done) + await wait_for_eq(watcher, "call_count", 1, 1) + assert watcher.call_args == call( + name="sim_motor", + current=0.0, + initial=0.0, + target=0.55, + unit="mm", + precision=3, + time_elapsed=pytest.approx(0.0, abs=0.05), + ) + watcher.reset_mock() + assert 0.55 == await sim_motor.user_setpoint.get_value() + assert not s.done + await asyncio.sleep(0.1) + set_mock_value(sim_motor.user_readback, 0.1) + await wait_for_eq(watcher, "call_count", 1, 1) + assert watcher.call_count == 1 + assert watcher.call_args == call( + name="sim_motor", + current=0.1, + initial=0.0, + target=0.55, + unit="mm", + precision=3, + time_elapsed=pytest.approx(0.1, abs=0.05), + ) + set_mock_value(sim_motor.motor_done_move, True) + set_mock_value(sim_motor.user_readback, 0.55) + set_mock_put_proceeds(sim_motor.user_setpoint, True) + await asyncio.sleep(A_BIT) + await wait_for_eq(s, "done", True, 1) + done.assert_called_once_with(s) + + +async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: set_mock_put_proceeds(sim_motor.user_setpoint, False) s = sim_motor.set(0.55) watcher = Mock() @@ -51,6 +101,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: assert not s.done await asyncio.sleep(0.1) set_mock_value(sim_motor.user_readback, 0.1) + await asyncio.sleep(0.1) assert watcher.call_count == 1 assert watcher.call_args == call( name="sim_motor", @@ -59,7 +110,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: target=0.55, unit="mm", precision=3, - time_elapsed=pytest.approx(0.1, abs=0.05), + time_elapsed=pytest.approx(0.11, abs=0.05), ) set_mock_put_proceeds(sim_motor.user_setpoint, True) await asyncio.sleep(A_BIT) @@ -68,6 +119,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: async def test_motor_moving_stopped(sim_motor: motor.Motor): + set_mock_value(sim_motor.motor_done_move, False) set_mock_put_proceeds(sim_motor.user_setpoint, False) s = sim_motor.set(1.5) s.add_callback(Mock()) @@ -107,15 +159,3 @@ async def test_set_velocity(sim_motor: motor.Motor) -> None: await v.set(3.0) assert (await v.read())["sim_motor-velocity"]["value"] == 3.0 assert q.empty() - - -def test_motor_in_re(sim_motor: motor.Motor, RE) -> None: - sim_motor.move(0) - - def my_plan(): - sim_motor.move(0) - return - yield - - with pytest.raises(RuntimeError, match="Will deadlock run engine if run in a plan"): - RE(my_plan()) diff --git a/tests/sim/demo/test_sim_motor.py b/tests/sim/demo/test_sim_motor.py index 7339a5d8f5..7c6ec074c0 100644 --- a/tests/sim/demo/test_sim_motor.py +++ b/tests/sim/demo/test_sim_motor.py @@ -4,6 +4,7 @@ from bluesky.plans import spiral_square from bluesky.run_engine import RunEngine +from ophyd_async.core.async_status import AsyncStatusBase from ophyd_async.core.device import DeviceCollector from ophyd_async.sim.demo.sim_motor import SimMotor @@ -12,8 +13,8 @@ async def test_move_sim_in_plan(): RE = RunEngine() async with DeviceCollector(): - m1 = SimMotor("M1", "sim_motor1") - m2 = SimMotor("M2", "sim_motor2") + m1 = SimMotor("M1") + m2 = SimMotor("M2") my_plan = spiral_square([], m1, m2, 0, 0, 4, 4, 10, 10) @@ -44,9 +45,12 @@ async def test_stop(): # this move should take 10 seconds but we will stop it after 0.2 move_status = m1.set(10) + while not m1._move_status: + # wait to actually get the move started + await asyncio.sleep(0) await asyncio.sleep(0.2) m1.stop() - + await asyncio.sleep(0) new_pos = await m1.user_readback.get_value() assert move_status.done @@ -70,16 +74,16 @@ async def test_timeout(): move_status = m1.set(10, timeout=0.1) await asyncio.sleep(0.2) - # verify status of inner task set up to run _move.update_position() - assert isinstance(m1._move_task, asyncio.Task) - assert m1._move_task.done - assert m1._move_task.cancelled - - # verify status of outer task set up to run _move() - assert move_status.task.done + # check inner status + assert m1._move_status is not None + assert m1._move_status.done + assert not move_status.success assert move_status.task.cancelled - new_pos = await m1.user_readback.get_value() - assert new_pos < 10 + # check outer status + assert isinstance(move_status, AsyncStatusBase) assert move_status.done assert not move_status.success + assert move_status.task.cancelled + new_pos = await m1.user_readback.get_value() + assert new_pos < 10